diff --git a/.flake8 b/.flake8
index 609fa2c03..41d8799c8 100644
--- a/.flake8
+++ b/.flake8
@@ -1,7 +1,7 @@
[flake8]
show-source=true
statistics=true
-max-line-length = 80
+max-line-length = 88
per-file-ignores =
# line too long
icefall/diagnostics.py: E501,
@@ -11,7 +11,8 @@ per-file-ignores =
egs/*/ASR/*/scaling.py: E501,
egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
- egs/librispeech/ASR/conformer_ctc2/*py: E501,
+ egs/librispeech/ASR/conformer_ctc*/*py: E501,
+ egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
egs/librispeech/ASR/RESULTS.md: E999,
# invalid escape sequence (cause by tex formular), W605
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 000000000..5d65b98e9
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,3 @@
+# Migrate to 88 characters per line (see: https://github.com/lhotse-speech/lhotse/issues/890)
+107df3b115a58f1b68a6458c3f94a130004be34c
+d31db010371a4128856480382876acdc0d1739ed
diff --git a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
index bb7c7dfdc..0bec8c0c4 100755
--- a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+++ b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
@@ -15,5 +15,5 @@ mkdir -p data
cd data
[ ! -e fbank ] && ln -s ~/tmp/fbank-libri fbank
cd ..
-./local/compute_fbank_librispeech.py
+./local/compute_fbank_librispeech.py --dataset 'test-clean test-other'
ls -lh data/fbank/
diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
index e70a1848d..4c393f6be 100755
--- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
+++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
@@ -25,7 +25,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
new file mode 100755
index 000000000..c68ccc954
--- /dev/null
+++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
@@ -0,0 +1,122 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27
+
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/lang_bpe_500/HLG.pt"
+git lfs pull --include "data/lang_bpe_500/L.pt"
+git lfs pull --include "data/lang_bpe_500/LG.pt"
+git lfs pull --include "data/lang_bpe_500/Linv.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lm/G_4_gram.pt"
+git lfs pull --include "exp/jit_trace.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Decode with models exported by torch.jit.trace()"
+
+for m in ctc-decoding 1best; do
+ ./conformer_ctc3/jit_pretrained.py \
+ --model-filename $repo/exp/jit_trace.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+log "Export to torchscript model"
+
+./conformer_ctc3/export.py \
+ --exp-dir $repo/exp \
+ --lang-dir $repo/data/lang_bpe_500 \
+ --jit-trace 1 \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.trace()"
+
+for m in ctc-decoding 1best; do
+ ./conformer_ctc3/jit_pretrained.py \
+ --model-filename $repo/exp/jit_trace.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for m in ctc-decoding 1best; do
+ ./conformer_ctc3/pretrained.py \
+ --checkpoint $repo/exp/pretrained.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p conformer_ctc3/exp
+ ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh conformer_ctc3/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in ctc-decoding 1best; do
+ log "Decoding with $method"
+ ./conformer_ctc3/decode.py \
+ --epoch 999 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --exp-dir conformer_ctc3/exp/ \
+ --max-duration $max_duration \
+ --decoding-method $method \
+ --lm-dir data/lm
+ done
+
+ rm conformer_ctc3/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
new file mode 100755
index 000000000..4cd2c4bec
--- /dev/null
+++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
@@ -0,0 +1,191 @@
+#!/usr/bin/env bash
+#
+set -e
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+abs_repo=$(realpath $repo)
+
+log "Display test files"
+tree $repo/
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+ln -s pretrained-iter-468000-avg-16.pt pretrained.pt
+ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
+popd
+
+log "Test exporting with torch.jit.trace()"
+
+./lstm_transducer_stateless2/export.py \
+ --exp-dir $repo/exp \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --jit-trace 1
+
+log "Decode with models exported by torch.jit.trace()"
+
+./lstm_transducer_stateless2/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --encoder-model-filename $repo/exp/encoder_jit_trace.pt \
+ --decoder-model-filename $repo/exp/decoder_jit_trace.pt \
+ --joiner-model-filename $repo/exp/joiner_jit_trace.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./lstm_transducer_stateless2/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./lstm_transducer_stateless2/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+
+if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
+ lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+ log "Download pre-trained RNN-LM model from ${lm_repo_url}"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
+ lm_repo=$(basename $lm_repo_url)
+ pushd $lm_repo
+ git lfs pull --include "exp/pretrained.pt"
+ mv exp/pretrained.pt exp/epoch-88.pt
+ popd
+
+ mkdir -p lstm_transducer_stateless2/exp
+ ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh lstm_transducer_stateless2/exp
+
+ log "Decoding test-clean and test-other with RNN LM"
+
+ ./lstm_transducer_stateless2/decode.py \
+ --use-averaged-model 0 \
+ --epoch 999 \
+ --avg 1 \
+ --exp-dir lstm_transducer_stateless2/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search_lm_shallow_fusion \
+ --beam 4 \
+ --use-shallow-fusion 1 \
+ --lm-type rnn \
+ --lm-exp-dir $lm_repo/exp \
+ --lm-epoch 88 \
+ --lm-avg 1 \
+ --lm-scale 0.3 \
+ --rnn-lm-num-layers 3 \
+ --rnn-lm-tie-weights 1
+fi
+
+if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then
+ bigram_repo_url=https://huggingface.co/marcoyang/librispeech_bigram
+ log "Download bi-gram LM from ${bigram_repo_url}"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
+ bigramlm_repo=$(basename $bigram_repo_url)
+ pushd $bigramlm_repo
+ git lfs pull --include "2gram.fst.txt"
+ cp 2gram.fst.txt $abs_repo/data/lang_bpe_500/.
+ popd
+
+ lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+ log "Download pre-trained RNN-LM model from ${lm_repo_url}"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
+ lm_repo=$(basename $lm_repo_url)
+ pushd $lm_repo
+ git lfs pull --include "exp/pretrained.pt"
+ mv exp/pretrained.pt exp/epoch-88.pt
+ popd
+
+ mkdir -p lstm_transducer_stateless2/exp
+ ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh lstm_transducer_stateless2/exp
+
+ log "Decoding test-clean and test-other"
+
+ ./lstm_transducer_stateless2/decode.py \
+ --use-averaged-model 0 \
+ --epoch 999 \
+ --avg 1 \
+ --exp-dir lstm_transducer_stateless2/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search_LODR \
+ --beam 4 \
+ --use-shallow-fusion 1 \
+ --lm-type rnn \
+ --lm-exp-dir $lm_repo/exp \
+ --lm-scale 0.4 \
+ --lm-epoch 88 \
+ --rnn-lm-avg 1 \
+ --rnn-lm-num-layers 3 \
+ --rnn-lm-tie-weights 1 \
+ --tokens-ngram 2 \
+ --ngram-lm-scale -0.16
+fi
+
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
+ mkdir -p lstm_transducer_stateless2/exp
+ ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh lstm_transducer_stateless2/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./lstm_transducer_stateless2/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --max-duration $max_duration \
+ --exp-dir lstm_transducer_stateless2/exp
+ done
+
+ rm lstm_transducer_stateless2/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
deleted file mode 100755
index 6ce92d022..000000000
--- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ /dev/null
@@ -1,233 +0,0 @@
-#!/usr/bin/env bash
-#
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-soxi $repo/test_wavs/*.wav
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-ln -s pretrained-iter-468000-avg-16.pt pretrained.pt
-ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
-popd
-
-log "Install ncnn and pnnx"
-
-# We are using a modified ncnn here. Will try to merge it to the official repo
-# of ncnn
-git clone https://github.com/csukuangfj/ncnn
-pushd ncnn
-git submodule init
-git submodule update python/pybind11
-python3 setup.py bdist_wheel
-ls -lh dist/
-pip install dist/*.whl
-cd tools/pnnx
-mkdir build
-cd build
-cmake ..
-make -j4 pnnx
-
-./src/pnnx || echo "pass"
-
-popd
-
-log "Test exporting to pnnx format"
-
-./lstm_transducer_stateless2/export.py \
- --exp-dir $repo/exp \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --epoch 99 \
- --avg 1 \
- --use-averaged-model 0 \
- --pnnx 1
-
-./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt
-./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt
-./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt
-
-./lstm_transducer_stateless2/ncnn-decode.py \
- --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
- --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
- --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
- --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
- --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
- --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
- --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
- $repo/test_wavs/1089-134686-0001.wav
-
-./lstm_transducer_stateless2/streaming-ncnn-decode.py \
- --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
- --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
- --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
- --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
- --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
- --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
- --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
- $repo/test_wavs/1089-134686-0001.wav
-
-
-
-log "Test exporting with torch.jit.trace()"
-
-./lstm_transducer_stateless2/export.py \
- --exp-dir $repo/exp \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --epoch 99 \
- --avg 1 \
- --use-averaged-model 0 \
- --jit-trace 1
-
-log "Decode with models exported by torch.jit.trace()"
-
-./lstm_transducer_stateless2/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --encoder-model-filename $repo/exp/encoder_jit_trace.pt \
- --decoder-model-filename $repo/exp/decoder_jit_trace.pt \
- --joiner-model-filename $repo/exp/joiner_jit_trace.pt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-log "Test exporting to ONNX"
-
-./lstm_transducer_stateless2/export.py \
- --exp-dir $repo/exp \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --epoch 99 \
- --avg 1 \
- --use-averaged-model 0 \
- --onnx 1
-
-log "Decode with ONNX models "
-
-./lstm_transducer_stateless2/streaming-onnx-decode.py \
- --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
- --encoder-model-filename $repo//exp/encoder.onnx \
- --decoder-model-filename $repo/exp/decoder.onnx \
- --joiner-model-filename $repo/exp/joiner.onnx \
- --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
- --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
- $repo/test_wavs/1089-134686-0001.wav
-
-./lstm_transducer_stateless2/streaming-onnx-decode.py \
- --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
- --encoder-model-filename $repo//exp/encoder.onnx \
- --decoder-model-filename $repo/exp/decoder.onnx \
- --joiner-model-filename $repo/exp/joiner.onnx \
- --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
- --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
- $repo/test_wavs/1221-135766-0001.wav
-
-./lstm_transducer_stateless2/streaming-onnx-decode.py \
- --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
- --encoder-model-filename $repo//exp/encoder.onnx \
- --decoder-model-filename $repo/exp/decoder.onnx \
- --joiner-model-filename $repo/exp/joiner.onnx \
- --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
- --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
- $repo/test_wavs/1221-135766-0002.wav
-
-
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./lstm_transducer_stateless2/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./lstm_transducer_stateless2/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-
-if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then
- lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
- log "Download pre-trained RNN-LM model from ${lm_repo_url}"
- git clone $lm_repo_url
- lm_repo=$(basename $lm_repo_url)
- pushd $lm_repo
- git lfs pull --include "exp/pretrained.pt"
- cd exp
- ln -s pretrained.pt epoch-88.pt
- popd
-
- ./lstm_transducer_stateless2/decode.py \
- --use-averaged-model 0 \
- --epoch 99 \
- --avg 1 \
- --exp-dir $repo/exp \
- --lang-dir $repo/data/lang_bpe_500 \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --max-duration 600 \
- --decoding-method modified_beam_search_rnnlm_shallow_fusion \
- --beam 4 \
- --rnn-lm-scale 0.3 \
- --rnn-lm-exp-dir $lm_repo/exp \
- --rnn-lm-epoch 88 \
- --rnn-lm-avg 1 \
- --rnn-lm-num-layers 3 \
- --rnn-lm-tie-weights 1
-fi
-
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
- mkdir -p lstm_transducer_stateless2/exp
- ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh lstm_transducer_stateless2/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./lstm_transducer_stateless2/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --max-duration $max_duration \
- --exp-dir lstm_transducer_stateless2/exp
- done
-
- rm lstm_transducer_stateless2/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
index dafea56db..6792c7088 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
index c3d07dc0e..dbf678d72 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
@@ -23,7 +23,6 @@ popd
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
index 22de3b45d..b6d477afe 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
@@ -22,7 +22,6 @@ popd
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
index 880767443..efa4b53f0 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
@@ -27,14 +26,6 @@ ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
popd
-log "Test exporting to ONNX format"
-
-./pruned_transducer_stateless3/export.py \
- --exp-dir $repo/exp \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --epoch 99 \
- --avg 1 \
- --onnx 1
log "Export to torchscript model"
./pruned_transducer_stateless3/export.py \
@@ -51,30 +42,8 @@ log "Export to torchscript model"
--avg 1 \
--jit-trace 1
-ls -lh $repo/exp/*.onnx
ls -lh $repo/exp/*.pt
-log "Decode with ONNX models"
-
-./pruned_transducer_stateless3/onnx_check.py \
- --jit-filename $repo/exp/cpu_jit.pt \
- --onnx-encoder-filename $repo/exp/encoder.onnx \
- --onnx-decoder-filename $repo/exp/decoder.onnx \
- --onnx-joiner-filename $repo/exp/joiner.onnx \
- --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
- --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
-
-./pruned_transducer_stateless3/onnx_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --encoder-model-filename $repo/exp/encoder.onnx \
- --decoder-model-filename $repo/exp/decoder.onnx \
- --joiner-model-filename $repo/exp/joiner.onnx \
- --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
- --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
log "Decode with models exported by torch.jit.trace()"
./pruned_transducer_stateless3/jit_pretrained.py \
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
index c6a781318..511fe0c9e 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
index 75861bbc7..2bc179c86 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
@@ -33,6 +32,7 @@ popd
log "Export to torchscript model"
./pruned_transducer_stateless7/export.py \
--exp-dir $repo/exp \
+ --use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
new file mode 100755
index 000000000..192438353
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
@@ -0,0 +1,150 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
+
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/lang_bpe_500/HLG.pt"
+git lfs pull --include "data/lang_bpe_500/L.pt"
+git lfs pull --include "data/lang_bpe_500/LG.pt"
+git lfs pull --include "data/lang_bpe_500/Linv.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lm/G_4_gram.pt"
+git lfs pull --include "exp/cpu_jit.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Export to torchscript model"
+./pruned_transducer_stateless7_ctc/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.script()"
+
+./pruned_transducer_stateless7_ctc/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+ --model-filename $repo/exp/cpu_jit.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless7_ctc/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless7_ctc/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
+ --checkpoint $repo/exp/pretrained.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless7_ctc/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless7_ctc/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless7_ctc/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless7_ctc/exp
+ done
+
+ for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc/ctc_decode.py \
+ --epoch 999 \
+ --avg 1 \
+ --exp-dir ./pruned_transducer_stateless7_ctc/exp \
+ --max-duration $max_duration \
+ --use-averaged-model 0 \
+ --decoding-method $m \
+ --hlg-scale 0.6 \
+ --lm-dir data/lm
+ done
+
+ rm pruned_transducer_stateless7_ctc/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
new file mode 100755
index 000000000..761eb72e2
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
@@ -0,0 +1,147 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
+
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/lang_bpe_500/HLG.pt"
+git lfs pull --include "data/lang_bpe_500/L.pt"
+git lfs pull --include "data/lang_bpe_500/LG.pt"
+git lfs pull --include "data/lang_bpe_500/Linv.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/cpu_jit.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Export to torchscript model"
+./pruned_transducer_stateless7_ctc_bs/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.script()"
+
+./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
+ --model-filename $repo/exp/cpu_jit.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
+ --checkpoint $repo/exp/pretrained.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless7_ctc_bs/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc_bs/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless7_ctc_bs/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless7_ctc_bs/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless7_ctc_bs/exp
+ done
+
+ for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \
+ --epoch 999 \
+ --avg 1 \
+ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
+ --max-duration $max_duration \
+ --use-averaged-model 0 \
+ --decoding-method $m \
+ --hlg-scale 0.6
+ done
+
+ rm pruned_transducer_stateless7_ctc_bs/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
new file mode 100755
index 000000000..e1e4e1f10
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
@@ -0,0 +1,148 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/cpu_jit.pt"
+git lfs pull --include "exp/pretrained.pt"
+git lfs pull --include "exp/encoder_jit_trace.pt"
+git lfs pull --include "exp/decoder_jit_trace.pt"
+git lfs pull --include "exp/joiner_jit_trace.pt"
+cd exp
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Export to torchscript model"
+./pruned_transducer_stateless7_streaming/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --decode-chunk-len 32 \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.script()"
+
+./pruned_transducer_stateless7_streaming/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ --decode-chunk-len 32 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+log "Export to torchscript model by torch.jit.trace()"
+./pruned_transducer_stateless7_streaming/jit_trace_export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --decode-chunk-len 32 \
+ --epoch 99 \
+ --avg 1
+
+log "Decode with models exported by torch.jit.trace()"
+
+./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --encoder-model-filename $repo/exp/encoder_jit_trace.pt \
+ --decoder-model-filename $repo/exp/decoder_jit_trace.pt \
+ --joiner-model-filename $repo/exp/joiner_jit_trace.pt \
+ --decode-chunk-len 32 \
+ $repo/test_wavs/1089-134686-0001.wav
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless7_streaming/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --decode-chunk-len 32 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless7_streaming/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --decode-chunk-len 32 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless7_streaming/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_streaming/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless7_streaming/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+ num_decode_stream=200
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+    log "Decoding with $method"
+
+ ./pruned_transducer_stateless7_streaming/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --max-duration $max_duration \
+ --decode-chunk-len 32 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp
+ done
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless7_streaming/streaming_decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --decode-chunk-len 32 \
+      --num-decode-streams $num_decode_stream \
+ --exp-dir pruned_transducer_stateless7_streaming/exp
+ done
+
+ rm pruned_transducer_stateless7_streaming/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
new file mode 100755
index 000000000..5d9485692
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
@@ -0,0 +1,115 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/cpu_jit.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Decode with models exported by torch.jit.script()"
+
+./pruned_transducer_stateless8/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+log "Export to torchscript model"
+./pruned_transducer_stateless8/export.py \
+ --exp-dir $repo/exp \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --use-averaged-model false \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.script()"
+
+./pruned_transducer_stateless8/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless8/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless8/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless8/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless8/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless8/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless8/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless8/exp
+ done
+
+ rm pruned_transducer_stateless8/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
index af37102d5..77cd59506 100755
--- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
+++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
index 5b8ed396b..b4aca1b6b 100755
--- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
+++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
new file mode 100755
index 000000000..a58b8ec56
--- /dev/null
+++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
@@ -0,0 +1,102 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
+
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+git lfs pull --include "data/lang_bpe_500/3gram.pt"
+git lfs pull --include "data/lang_bpe_500/4gram.pt"
+git lfs pull --include "data/lang_bpe_500/L.pt"
+git lfs pull --include "data/lang_bpe_500/LG.pt"
+git lfs pull --include "data/lang_bpe_500/Linv.pt"
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/cpu_jit.pt"
+git lfs pull --include "exp/pretrained.pt"
+ln -s pretrained.pt epoch-99.pt
+ls -lh *.pt
+popd
+
+log "Export to torchscript model"
+./zipformer_mmi/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ls -lh $repo/exp/*.pt
+
+log "Decode with models exported by torch.jit.script()"
+
+./zipformer_mmi/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ --lang-dir $repo/data/lang_bpe_500 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+ log "$method"
+
+ ./zipformer_mmi/pretrained.py \
+ --method $method \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_bpe_500 \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p zipformer_mmi/exp
+ ln -s $PWD/$repo/exp/pretrained.pt zipformer_mmi/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh zipformer_mmi/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+ log "Decoding with $method"
+
+ ./zipformer_mmi/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --nbest-scale 1.2 \
+ --hp-scale 1.0 \
+ --max-duration $max_duration \
+ --lang-dir $repo/data/lang_bpe_500 \
+ --exp-dir zipformer_mmi/exp
+ done
+
+ rm zipformer_mmi/exp/*.pt
+fi
diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh
index 96c320616..125d1f3b1 100755
--- a/.github/scripts/run-pre-trained-conformer-ctc.sh
+++ b/.github/scripts/run-pre-trained-conformer-ctc.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.flac
ls -lh $repo/test_wavs/*.flac
log "CTC decoding"
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
index 209d4814f..89115e88d 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
index 34ff76fe4..85e2c89e6 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
index 75650c2d3..0644d9be0 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
index bcc2d74cb..79fb64311 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh
index d3e40315a..41456f11b 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh
index cfa006776..1331c966c 100755
--- a/.github/scripts/run-pre-trained-transducer.sh
+++ b/.github/scripts/run-pre-trained-transducer.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
log "Beam search decoding"
diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
index 2d237dcf2..90097c752 100755
--- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
+++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
@@ -20,7 +20,6 @@ repo=$(basename $repo_url)
log "Display test files"
tree $repo/
-soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh
new file mode 100755
index 000000000..52491d2ea
--- /dev/null
+++ b/.github/scripts/test-ncnn-export.sh
@@ -0,0 +1,234 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+pushd egs/librispeech/ASR
+
+log "Install ncnn and pnnx"
+
+# We are using a modified version of ncnn here. We will try to merge it into
+# the official ncnn repo.
+git clone https://github.com/csukuangfj/ncnn
+pushd ncnn
+git submodule init
+git submodule update python/pybind11
+python3 setup.py bdist_wheel
+ls -lh dist/
+pip install dist/*.whl
+cd tools/pnnx
+mkdir build
+cd build
+
+echo "which python3"
+
+which python3
+# e.g. /opt/hostedtoolcache/Python/3.8.16/x64/bin/python3
+
+cmake -D Python3_EXECUTABLE=$(which python3) ..
+make -j4 pnnx
+
+./src/pnnx || echo "pass"
+
+popd
+
+export PATH=$PWD/ncnn/tools/pnnx/build/src:$PATH
+
+log "=========================================================================="
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
+
+cd exp
+ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.trace()"
+
+./conv_emformer_transducer_stateless2/export-for-ncnn.py \
+ --exp-dir $repo/exp \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32
+
+pnnx $repo/exp/encoder_jit_trace-pnnx.pt
+pnnx $repo/exp/decoder_jit_trace-pnnx.pt
+pnnx $repo/exp/joiner_jit_trace-pnnx.pt
+
+python3 ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+ --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+ --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+ --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+ --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+ --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+ $repo/test_wavs/1089-134686-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
+
+cd exp
+ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.trace()"
+
+./lstm_transducer_stateless2/export-for-ncnn.py \
+ --exp-dir $repo/exp \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0
+
+pnnx $repo/exp/encoder_jit_trace-pnnx.pt
+pnnx $repo/exp/decoder_jit_trace-pnnx.pt
+pnnx $repo/exp/joiner_jit_trace-pnnx.pt
+
+python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+ --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+ --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+ --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+ --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+ --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+ $repo/test_wavs/1089-134686-0001.wav
+
+python3 ./lstm_transducer_stateless2/ncnn-decode.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+ --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+ --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+ --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+ --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+ --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+ $repo/test_wavs/1089-134686-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --exp-dir $repo/exp \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ \
+ --decode-chunk-len 32 \
+ --num-encoder-layers "2,4,3,2,4" \
+ --feedforward-dims "1024,1024,2048,2048,1024" \
+ --nhead "8,8,8,8,8" \
+ --encoder-dims "384,384,384,384,384" \
+ --attention-dims "192,192,192,192,192" \
+ --encoder-unmasked-dims "256,256,256,256,256" \
+ --zipformer-downsampling-factors "1,2,4,8,2" \
+ --cnn-module-kernels "31,31,31,31,31" \
+ --decoder-dim 512 \
+ --joiner-dim 512
+
+pnnx $repo/exp/encoder_jit_trace-pnnx.pt
+pnnx $repo/exp/decoder_jit_trace-pnnx.pt
+pnnx $repo/exp/joiner_jit_trace-pnnx.pt
+
+python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+ --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+ --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+ --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+ --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+ --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+ $repo/test_wavs/1089-134686-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_char_bpe/L.pt"
+git lfs pull --include "data/lang_char_bpe/L_disambig.pt"
+git lfs pull --include "data/lang_char_bpe/Linv.pt"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
+ --lang-dir $repo/data/lang_char_bpe \
+ --exp-dir $repo/exp \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --decode-chunk-len 32 \
+ --num-encoder-layers "2,4,3,2,4" \
+ --feedforward-dims "1024,1024,1536,1536,1024" \
+ --nhead "8,8,8,8,8" \
+ --encoder-dims "384,384,384,384,384" \
+ --attention-dims "192,192,192,192,192" \
+ --encoder-unmasked-dims "256,256,256,256,256" \
+ --zipformer-downsampling-factors "1,2,4,8,2" \
+ --cnn-module-kernels "31,31,31,31,31" \
+ --decoder-dim 512 \
+ --joiner-dim 512
+
+pnnx $repo/exp/encoder_jit_trace-pnnx.pt
+pnnx $repo/exp/decoder_jit_trace-pnnx.pt
+pnnx $repo/exp/joiner_jit_trace-pnnx.pt
+
+python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
+ --tokens $repo/data/lang_char_bpe/tokens.txt \
+ --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \
+ --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \
+ --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \
+ --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \
+ --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \
+ --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \
+ $repo/test_wavs/0.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh
new file mode 100755
index 000000000..39467c44a
--- /dev/null
+++ b/.github/scripts/test-onnx-export.sh
@@ -0,0 +1,351 @@
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+
+
+log "=========================================================================="
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.trace()"
+
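+# This writes encoder_jit_trace.pt, decoder_jit_trace.pt and joiner_jit_trace.pt,
+# which serve as the reference models for onnx_check.py below.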
+./pruned_transducer_stateless7_streaming/jit_trace_export.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --decode-chunk-len 32 \
+ --exp-dir $repo/exp/
+
+log "Test exporting to ONNX format"
+
+./pruned_transducer_stateless7_streaming/export-onnx.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --decode-chunk-len 32 \
+ --exp-dir $repo/exp/
+
+ls -lh $repo/exp
+
+log "Run onnx_check.py"
+
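+# Compare the outputs of the ONNX models against the torch.jit models to
+# verify that the ONNX export is correct.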
+./pruned_transducer_stateless7_streaming/onnx_check.py \
+ --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
+ --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
+ --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
+ --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
+
+log "Run onnx_pretrained.py"
+
+./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-iter-1224000-avg-14.pt"
+
+cd exp
+ln -s pretrained-iter-1224000-avg-14.pt epoch-9999.pt
+popd
+
+log "Export via torch.jit.script()"
+
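+# --jit 1 saves a torch.jit.script model as cpu_jit.pt, used as the reference
+# by onnx_check.py below.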
+./pruned_transducer_stateless3/export.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --epoch 9999 \
+ --avg 1 \
+ --exp-dir $repo/exp/ \
+ --jit 1
+
+log "Test exporting to ONNX format"
+
+./pruned_transducer_stateless3/export-onnx.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --epoch 9999 \
+ --avg 1 \
+ --exp-dir $repo/exp/
+
+ls -lh $repo/exp
+
+log "Run onnx_check.py"
+
+./pruned_transducer_stateless3/onnx_check.py \
+ --jit-filename $repo/exp/cpu_jit.pt \
+ --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
+ --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
+ --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx
+
+log "Run onnx_pretrained.py"
+
+./pruned_transducer_stateless3/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+
+log "=========================================================================="
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-epoch-39-avg-7.pt"
+
+cd exp
+ln -s pretrained-epoch-39-avg-7.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.script()"
+
+./pruned_transducer_stateless5/export.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --exp-dir $repo/exp \
+ --num-encoder-layers 18 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --encoder-dim 512 \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --jit 1
+
+log "Test exporting to ONNX format"
+
+./pruned_transducer_stateless5/export-onnx.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --exp-dir $repo/exp \
+ --num-encoder-layers 18 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --encoder-dim 512 \
+ --decoder-dim 512 \
+ --joiner-dim 512
+
+ls -lh $repo/exp
+
+log "Run onnx_check.py"
+
+./pruned_transducer_stateless5/onnx_check.py \
+ --jit-filename $repo/exp/cpu_jit.pt \
+ --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
+
+log "Run onnx_pretrained.py"
+
+./pruned_transducer_stateless5/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.script()"
+
+./pruned_transducer_stateless7/export.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --feedforward-dims "1024,1024,2048,2048,1024" \
+ --jit 1
+
+log "Test exporting to ONNX format"
+
+./pruned_transducer_stateless7/export-onnx.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --feedforward-dims "1024,1024,2048,2048,1024"
+
+ls -lh $repo/exp
+
+log "Run onnx_check.py"
+
+./pruned_transducer_stateless7/onnx_check.py \
+ --jit-filename $repo/exp/cpu_jit.pt \
+ --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
+
+log "Run onnx_pretrained.py"
+
+./pruned_transducer_stateless7/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
+
+cd exp
+ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt
+popd
+
+log "Test exporting to ONNX format"
+
+./conv_emformer_transducer_stateless2/export-onnx.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32
+
+log "Run onnx_pretrained.py"
+
+./conv_emformer_transducer_stateless2/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1221-135766-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
+
+log "=========================================================================="
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
+
+cd exp
+ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
+popd
+
+log "Export via torch.jit.trace()"
+
+./lstm_transducer_stateless2/export.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp/ \
+ --jit-trace 1
+
+log "Test exporting to ONNX format"
+
+./lstm_transducer_stateless2/export-onnx.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp
+
+ls -lh $repo/exp
+
+log "Run onnx_check.py"
+
+./lstm_transducer_stateless2/onnx_check.py \
+ --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \
+ --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \
+ --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \
+ --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
+
+log "Run onnx_pretrained.py"
+
+./lstm_transducer_stateless2/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1221-135766-0001.wav
+
+rm -rf $repo
+log "--------------------------------------------------------------------------"
diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml
index dd0969f51..d7fe2c964 100644
--- a/.github/workflows/build-doc.yml
+++ b/.github/workflows/build-doc.yml
@@ -26,6 +26,10 @@ on:
pull_request:
types: [labeled]
+concurrency:
+ group: build_doc-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
build-doc:
if: github.event.label.name == 'doc' || github.event_name == 'push'
diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml
index e46b01a08..f5ba73195 100644
--- a/.github/workflows/run-aishell-2022-06-20.yml
+++ b/.github/workflows/run-aishell-2022-06-20.yml
@@ -34,6 +34,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_aishell_2022_06_20-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_aishell_2022_06_20:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -61,7 +65,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -83,7 +87,7 @@ jobs:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml
index c631927fa..c7b9cc79d 100644
--- a/.github/workflows/run-gigaspeech-2022-05-13.yml
+++ b/.github/workflows/run-gigaspeech-2022-05-13.yml
@@ -33,6 +33,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_gigaspeech_2022_05_13-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_gigaspeech_2022_05_13:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -60,7 +64,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml
index 5df710006..9c7cd1228 100644
--- a/.github/workflows/run-librispeech-2022-03-12.yml
+++ b/.github/workflows/run-librispeech-2022-03-12.yml
@@ -33,6 +33,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_librispeech_2022_03_12-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_librispeech_2022_03_12:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -60,7 +64,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -119,7 +123,7 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml
index 24c062442..78c9e759f 100644
--- a/.github/workflows/run-librispeech-2022-04-29.yml
+++ b/.github/workflows/run-librispeech-2022-04-29.yml
@@ -33,6 +33,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_librispeech_2022_04_29-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_librispeech_2022_04_29:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -60,7 +64,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -119,7 +123,7 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml
index 29215ec25..04799bf52 100644
--- a/.github/workflows/run-librispeech-2022-05-13.yml
+++ b/.github/workflows/run-librispeech-2022-05-13.yml
@@ -33,6 +33,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_librispeech_2022_05_13-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_librispeech_2022_05_13:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -60,7 +64,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -119,7 +123,7 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
index 3b98b500e..6dfc23920 100644
--- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
+++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
@@ -33,6 +33,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_librispeech_2022_11_11_zipformer-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_librispeech_2022_11_11_zipformer:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -60,7 +64,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -119,7 +123,7 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
new file mode 100644
index 000000000..0544e68b3
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
@@ -0,0 +1,159 @@
+# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-11-14-stateless8
+# zipformer
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+concurrency:
+ group: run_librispeech_2022_11_14_zipformer_stateless8-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ run_librispeech_2022_11_14_zipformer_stateless8:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: [3.8]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
+
+ - name: Display decoding results for librispeech pruned_transducer_stateless8
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./pruned_transducer_stateless8/exp
+
+ cd pruned_transducer_stateless8
+ echo "results for pruned_transducer_stateless8"
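+        # Each decoding run writes log-* files; collect their "best for test-clean"
+        # and "best for test-other" summary lines.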
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for librispeech pruned_transducer_stateless8
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless8-2022-11-14
+ path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/
diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
new file mode 100644
index 000000000..62e1f2a01
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
@@ -0,0 +1,163 @@
+# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-12-01-stateless7-ctc
+# zipformer
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+  run_librispeech_2022_12_01_zipformer_ctc:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: [3.8]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
+
+ - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./pruned_transducer_stateless7_ctc/exp
+
+ cd pruned_transducer_stateless7_ctc
+ echo "results for pruned_transducer_stateless7_ctc"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===ctc decoding==="
+ find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===1best==="
+ find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-2022-12-01
+ path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/
diff --git a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
new file mode 100644
index 000000000..7dc33aaa9
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
@@ -0,0 +1,167 @@
+# Copyright 2022 Zengwei Yao
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-12-08-zipformer-mmi
+# zipformer
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+concurrency:
+ group: run_librispeech_2022_12_08_zipformer-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ run_librispeech_2022_12_08_zipformer:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: [3.8]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
+
+ - name: Display decoding results for librispeech zipformer-mmi
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+        tree ./zipformer_mmi/exp
+
+        cd zipformer_mmi
+        echo "results for zipformer_mmi"
+ echo "===1best==="
+ find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===nbest==="
+ find exp/nbest -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/nbest -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===nbest-rescoring-LG==="
+ find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===nbest-rescoring-3-gram==="
+ find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===nbest-rescoring-4-gram==="
+ find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for librispeech zipformer-mmi
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer_mmi-2022-12-08
+ path: egs/librispeech/ASR/zipformer_mmi/exp/
diff --git a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml
new file mode 100644
index 000000000..de55847ad
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml
@@ -0,0 +1,163 @@
+# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-12-15-stateless7-ctc-bs
+# zipformer
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_librispeech_2022_12_15_zipformer_ctc_bs:
+ if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: [3.8]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
+
+ - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./pruned_transducer_stateless7_ctc_bs/exp
+
+ cd pruned_transducer_stateless7_ctc_bs
+ echo "results for pruned_transducer_stateless7_ctc_bs"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===ctc decoding==="
+ find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===1best==="
+ find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc_bs
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-bs-2022-12-15
+ path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/
diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml
new file mode 100644
index 000000000..feb5c6fd0
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml
@@ -0,0 +1,172 @@
+# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-12-29-stateless7-streaming
+# zipformer
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+concurrency:
+ group: run_librispeech_2022_12_29_zipformer_streaming-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ run_librispeech_2022_12_29_zipformer_streaming:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: [3.8]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
+
+ - name: Display decoding results for librispeech pruned_transducer_stateless7_streaming
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./pruned_transducer_stateless7_streaming/exp
+
+ cd pruned_transducer_stateless7_streaming
+ echo "results for pruned_transducer_stateless7_streaming"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===streaming greedy search==="
+ find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===streaming fast_beam_search==="
+ find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===streaming modified beam search==="
+ find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+
+ - name: Upload decoding results for librispeech pruned_transducer_stateless7_streaming
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-streaming-2022-12-29
+ path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/
diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
new file mode 100644
index 000000000..c95ed8b9a
--- /dev/null
+++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
@@ -0,0 +1,155 @@
+# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-conformer-ctc3-2022-11-28
+# conformer_ctc3
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+concurrency:
+ group: run_librispeech_2022_11_28_conformer_ctc3-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ run_librispeech_2022_11_28_conformer_ctc3:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: [3.8]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
+
+ - name: Display decoding results for librispeech conformer_ctc3
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./conformer_ctc3/exp
+
+ cd conformer_ctc3
+ echo "results for conformer_ctc3"
+ echo "===ctc-decoding==="
+ find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===1best==="
+ find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for librispeech conformer_ctc3
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-conformer_ctc3-2022-11-28
+ path: egs/librispeech/ASR/conformer_ctc3/exp/
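For reference, the "Display decoding results" step of the new workflow can be reproduced outside CI. Below is a minimal bash sketch, assuming the decode logs follow icefall's usual naming (`log-*`) and contain the "best for test-clean" / "best for test-other" summary lines that the grep commands above rely on:

```bash
#!/usr/bin/env bash
# Sketch: summarize decoding results locally, mirroring the CI step above.
set -euo pipefail

cd egs/librispeech/ASR/conformer_ctc3

for method in ctc-decoding 1best; do
  echo "===${method}==="
  for part in test-clean test-other; do
    find exp/${method} -name "log-*" \
      -exec grep -n --color "best for ${part}" {} + | sort -n -k2
  done
done
```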
diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
index a90841fb6..e14d4e92f 100644
--- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
@@ -16,9 +16,13 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_librispeech_lstm_transducer_stateless2_2022_09_03-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_librispeech_lstm_transducer_stateless2_2022_09_03:
- if: github.event.label.name == 'ready' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
+ if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@@ -43,7 +47,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -102,12 +106,12 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
- .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
+ .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
- name: Display decoding results for lstm_transducer_stateless2
if: github.event_name == 'schedule'
@@ -135,13 +139,25 @@ jobs:
cd egs/librispeech/ASR
tree lstm_transducer_stateless2/exp
cd lstm_transducer_stateless2/exp
- echo "===modified_beam_search_rnnlm_shallow_fusion==="
- find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+ echo "===modified_beam_search_lm_shallow_fusion==="
+ echo "===Using RNNLM==="
+ find modified_beam_search_lm_shallow_fusion -name "log-*rnn*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find modified_beam_search_lm_shallow_fusion -name "log-*rnn*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Display decoding results for lstm_transducer_stateless2
+ if: github.event.label.name == 'LODR'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR
+ tree lstm_transducer_stateless2/exp
+ cd lstm_transducer_stateless2/exp
+ echo "===modified_beam_search_rnnlm_LODR==="
+ find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
- name: Upload decoding results for lstm_transducer_stateless2
uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion'
+ if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
index 66a2c240b..73d91fcd4 100644
--- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
+++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
@@ -33,9 +33,13 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_librispeech_pruned_transducer_stateless3_2022_05_13-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_librispeech_pruned_transducer_stateless3_2022_05_13:
- if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@@ -60,7 +64,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -119,7 +123,7 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
index 55428861c..8a690393e 100644
--- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
+++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
@@ -33,6 +33,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_librispeech_streaming_2022_06_26-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_librispeech_streaming_2022_06_26:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -60,7 +64,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -119,7 +123,7 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
index f520405e1..217dbdfa1 100644
--- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
+++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
@@ -33,6 +33,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_librispeech_2022_04_19-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_librispeech_2022_04_19:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -60,7 +64,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -119,7 +123,7 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml
index 9bc6a481f..4e8e7b8db 100644
--- a/.github/workflows/run-pretrained-conformer-ctc.yml
+++ b/.github/workflows/run-pretrained-conformer-ctc.yml
@@ -23,6 +23,10 @@ on:
pull_request:
types: [labeled]
+concurrency:
+ group: run_pre_trained_conformer_ctc-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_pre_trained_conformer_ctc:
if: github.event.label.name == 'ready' || github.event_name == 'push'
@@ -50,7 +54,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -69,7 +73,7 @@ jobs:
- name: Inference with pre-trained model
shell: bash
run: |
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
index 7a0f30b0f..ddde4f1d6 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
@@ -32,6 +32,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -59,7 +63,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -118,7 +122,7 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
index 797f3fe50..00ea97b2a 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
@@ -32,6 +32,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -59,7 +63,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -118,7 +122,7 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
index 29e665881..b3cfc9efd 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
@@ -23,6 +23,10 @@ on:
pull_request:
types: [labeled]
+concurrency:
+ group: run_pre_trained_transducer_stateless_modified_2_aishell-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_pre_trained_transducer_stateless_modified_2_aishell:
if: github.event.label.name == 'ready' || github.event_name == 'push'
@@ -50,7 +54,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -69,7 +73,7 @@ jobs:
- name: Inference with pre-trained model
shell: bash
run: |
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
index 6193f28e7..ab598541d 100644
--- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
@@ -23,6 +23,10 @@ on:
pull_request:
types: [labeled]
+concurrency:
+ group: run_pre_trained_transducer_stateless_modified_aishell-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_pre_trained_transducer_stateless_modified_aishell:
if: github.event.label.name == 'ready' || github.event_name == 'push'
@@ -50,7 +54,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -69,7 +73,7 @@ jobs:
- name: Inference with pre-trained model
shell: bash
run: |
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml
index 32208076c..d663d49dd 100644
--- a/.github/workflows/run-pretrained-transducer-stateless.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless.yml
@@ -32,6 +32,10 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+concurrency:
+ group: run_pre_trained_transducer_stateless-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_pre_trained_transducer_stateless:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
@@ -59,7 +63,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -118,7 +122,7 @@ jobs:
ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
ls -lh egs/librispeech/ASR/data/*
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml
index 965d0f655..9cb9d3b59 100644
--- a/.github/workflows/run-pretrained-transducer.yml
+++ b/.github/workflows/run-pretrained-transducer.yml
@@ -23,6 +23,10 @@ on:
pull_request:
types: [labeled]
+concurrency:
+ group: run_pre_trained_transducer-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
run_pre_trained_transducer:
if: github.event.label.name == 'ready' || github.event_name == 'push'
@@ -50,7 +54,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -69,7 +73,7 @@ jobs:
- name: Inference with pre-trained model
shell: bash
run: |
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml
new file mode 100644
index 000000000..f8d9c02c5
--- /dev/null
+++ b/.github/workflows/run-ptb-rnn-lm.yml
@@ -0,0 +1,71 @@
+name: run-ptb-rnn-lm-training
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+concurrency:
+ group: run_ptb_rnn_lm_training-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ run_ptb_rnn_lm_training:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.8"]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
+
+ - name: Prepare data
+ shell: bash
+ run: |
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ cd egs/ptb/LM
+ ./prepare.sh
+
+ - name: Run training
+ shell: bash
+ run: |
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ cd egs/ptb/LM
+ ./train-rnn-lm.sh --world-size 1 --num-epochs 5 --use-epoch 4 --use-avg 2
+
+ - name: Upload pretrained models
+ uses: actions/upload-artifact@v2
+ if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
+ with:
+ name: python-${{ matrix.python-version }}-ubuntu-rnn-lm-ptb
+ path: egs/ptb/LM/my-rnnlm-exp/
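The new PTB RNN-LM workflow boils down to two recipe scripts. A local run with the same settings might look like the sketch below, assuming icefall is checked out and its root is on `PYTHONPATH`:

```bash
# Sketch: run the PTB RNN-LM recipe locally with the flags used by the CI job.
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/ptb/LM
./prepare.sh
./train-rnn-lm.sh --world-size 1 --num-epochs 5 --use-epoch 4 --use-avg 2
# Checkpoints land in my-rnnlm-exp/, the directory the workflow uploads as an artifact.
```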
diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
index d96a3bfe6..14fb96ec8 100644
--- a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
+++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml
@@ -23,8 +23,12 @@ on:
pull_request:
types: [labeled]
+concurrency:
+ group: run_wenetspeech_pruned_transducer_stateless2-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
- run_librispeech_pruned_transducer_stateless3_2022_05_13:
+ run_wenetspeech_pruned_transducer_stateless2:
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'wenetspeech'
runs-on: ${{ matrix.os }}
strategy:
@@ -50,7 +54,7 @@ jobs:
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Cache kaldifeat
id: my-cache
@@ -72,7 +76,7 @@ jobs:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
- sudo apt-get -qq install git-lfs tree sox
+ sudo apt-get -qq install git-lfs tree
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml
index ce77c47df..83a1d5462 100644
--- a/.github/workflows/run-yesno-recipe.yml
+++ b/.github/workflows/run-yesno-recipe.yml
@@ -21,17 +21,21 @@ on:
branches:
- master
pull_request:
- types: [labeled]
+ branches:
+ - master
+
+concurrency:
+ group: run-yesno-recipe-${{ github.ref }}
+ cancel-in-progress: true
jobs:
run-yesno-recipe:
- if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
# os: [ubuntu-18.04, macos-10.15]
# TODO: enable macOS for CPU testing
- os: [ubuntu-18.04]
+ os: [ubuntu-latest]
python-version: [3.8]
fail-fast: false
@@ -61,9 +65,9 @@ jobs:
- name: Install Python dependencies
run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
- name: Run yesno recipe
shell: bash
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 90459bc1c..fc1dcbfd4 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -24,6 +24,10 @@ on:
branches:
- master
+concurrency:
+ group: style_check-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
style_check:
runs-on: ${{ matrix.os }}
@@ -45,17 +49,18 @@ jobs:
- name: Install Python dependencies
run: |
- python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
- # See https://github.com/psf/black/issues/2964
- # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
+ python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
+ # Click issue fixed in https://github.com/psf/black/pull/2966
- name: Run flake8
shell: bash
working-directory: ${{github.workspace}}
run: |
# stop the build if there are Python syntax errors or undefined names
- flake8 . --count --show-source --statistics
- flake8 .
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+ flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
+ --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
- name: Run black
shell: bash
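To reproduce the updated style check locally, the same two flake8 passes can be run as shown below; the flags are copied from the workflow, and only the pinned tool versions come from this diff:

```bash
# Sketch: mirror the CI style check on a local checkout.
pip install black==22.3.0 flake8==5.0.4 click==8.1.0

# Pass 1: hard failures only (Python syntax errors and undefined names).
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics

# Pass 2: report everything else without failing the build (exit-zero).
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \
  --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503
```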
diff --git a/.github/workflows/test-ncnn-export.yml b/.github/workflows/test-ncnn-export.yml
new file mode 100644
index 000000000..cdea54854
--- /dev/null
+++ b/.github/workflows/test-ncnn-export.yml
@@ -0,0 +1,75 @@
+name: test-ncnn-export
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+concurrency:
+ group: test_ncnn_export-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ test_ncnn_export:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: [3.8]
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Test ncnn export
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/test-ncnn-export.sh
diff --git a/.github/workflows/test-onnx-export.yml b/.github/workflows/test-onnx-export.yml
new file mode 100644
index 000000000..3dc4261ab
--- /dev/null
+++ b/.github/workflows/test-onnx-export.yml
@@ -0,0 +1,75 @@
+name: test-onnx-export
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+concurrency:
+ group: test_onnx_export-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ test_onnx_export:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest]
+ python-version: [3.8]
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Test ONNX export
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/test-onnx-export.sh
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 04fc0265f..079772e97 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -21,26 +21,23 @@ on:
branches:
- master
pull_request:
- types: [labeled]
+ branches:
+ - master
+
+concurrency:
+ group: test-${{ github.ref }}
+ cancel-in-progress: true
jobs:
test:
- if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
- # os: [ubuntu-18.04, macos-10.15]
- # disable macOS test for now.
- os: [ubuntu-18.04]
- python-version: [3.7, 3.8]
- torch: ["1.8.0", "1.11.0"]
- torchaudio: ["0.8.0", "0.11.0"]
- k2-version: ["1.15.1.dev20220427"]
- exclude:
- - torch: "1.8.0"
- torchaudio: "0.11.0"
- - torch: "1.11.0"
- torchaudio: "0.8.0"
+ os: [ubuntu-latest]
+ python-version: ["3.8"]
+ torch: ["1.10.0"]
+ torchaudio: ["0.10.0"]
+ k2-version: ["1.23.2.dev20221201"]
fail-fast: false
@@ -59,7 +56,7 @@ jobs:
run: |
sudo apt update
sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
- sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all
+ sudo apt install -q -y --fix-missing libsox-dev libsox-fmt-all
- name: Install Python dependencies
run: |
@@ -67,21 +64,16 @@ jobs:
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- if [[ ${{ matrix.torchaudio }} == "0.11.0" ]]; then
- pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- else
- pip install torchaudio==${{ matrix.torchaudio }}
- fi
+ pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
pip install git+https://github.com/lhotse-speech/lhotse
# icefall requirements
pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
pip install kaldifst
pip install onnxruntime
-
pip install -r requirements.txt
- name: Install graphviz
@@ -121,19 +113,20 @@ jobs:
cd ../pruned_transducer_stateless4
pytest -v -s
+ cd ../pruned_transducer_stateless7
+ pytest -v -s
+
cd ../transducer_stateless
pytest -v -s
- if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
- cd ../transducer
- pytest -v -s
+ # cd ../transducer
+ # pytest -v -s
- cd ../transducer_stateless2
- pytest -v -s
+ cd ../transducer_stateless2
+ pytest -v -s
- cd ../transducer_lstm
- pytest -v -s
- fi
+ cd ../transducer_lstm
+ pytest -v -s
- name: Run tests
if: startsWith(matrix.os, 'macos')
@@ -164,13 +157,11 @@ jobs:
cd ../transducer_stateless
pytest -v -s
- if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
- cd ../transducer
- pytest -v -s
+ # cd ../transducer
+ # pytest -v -s
- cd ../transducer_stateless2
- pytest -v -s
+ cd ../transducer_stateless2
+ pytest -v -s
- cd ../transducer_lstm
- pytest -v -s
- fi
+ cd ../transducer_lstm
+ pytest -v -s
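The trimmed matrix above pins a single CPU configuration. Recreating that environment locally might look like the following sketch; the versions and wheel indexes are the ones the workflow itself installs:

```bash
# Sketch: recreate the CI's CPU-only test environment.
pip install numpy==1.19
pip install torch==1.10.0+cpu torchaudio==0.10.0+cpu \
  -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==1.23.2.dev20221201+cpu.torch1.10.0 -f https://k2-fsa.org/nightly/
pip install git+https://github.com/lhotse-speech/lhotse
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf==3.20.*
pip install kaldifst onnxruntime
pip install -r requirements.txt
```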
diff --git a/.gitignore b/.gitignore
index 406deff6a..8af05d884 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,5 +11,26 @@ log
*.bak
*-bak
*bak.py
+
+# Ignore Mac system files
+.DS_store
+
+# Ignore node_modules folder
+node_modules
+
+# ignore .nfs
+
+.nfs*
+
+# Ignore all text files
+*.txt
+
+# Ignore files related to API keys
+.env
+
+# Ignore SASS config files
+.sass-cache
+
*.param
*.bin
+.DS_Store
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 446ba0fe7..5cb213327 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,26 +1,38 @@
repos:
- repo: https://github.com/psf/black
- rev: 21.6b0
+ rev: 22.3.0
hooks:
- id: black
- args: [--line-length=80]
- additional_dependencies: ['click==8.0.1']
+ args: ["--line-length=88"]
+ additional_dependencies: ['click==8.1.0']
exclude: icefall\/__init__\.py
- repo: https://github.com/PyCQA/flake8
- rev: 3.9.2
+ rev: 5.0.4
hooks:
- id: flake8
- args: [--max-line-length=80]
+ args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
+
+ # What are we ignoring here?
+ # E203: whitespace before ':'
+ # E266: too many leading '#' for block comment
+ # E501: line too long
+ # F401: module imported but unused
+ # E402: module level import not at top of file
+ # F403: 'from module import *' used; unable to detect undefined names
+ # F841: local variable is assigned to but never used
+ # W503: line break before binary operator
+ # In addition, the default ignore list is:
+ # E121,E123,E126,E226,E24,E704,W503,W504
- repo: https://github.com/pycqa/isort
- rev: 5.9.2
+ rev: 5.10.1
hooks:
- id: isort
- args: [--profile=black, --line-length=80]
+ args: ["--profile=black"]
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.0.1
+ rev: v4.2.0
hooks:
- id: check-executables-have-shebangs
- id: end-of-file-fixer
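With the hook versions bumped as above, contributors can pick up the new configuration locally. A minimal sketch, assuming `pre-commit` itself is installed via pip:

```bash
# Sketch: install and run the updated pre-commit hooks on a local checkout.
pip install pre-commit
cd icefall
pre-commit install          # register the git hook once
pre-commit run --all-files  # black 22.3.0, flake8 5.0.4, isort 5.10.1, misc hooks
```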
diff --git a/LICENSE b/LICENSE
index ee06cfc77..d64569567 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,13 +1,4 @@
- Legal Notices
-
- NOTE (this is not from the Apache License): The copyright model is that
- authors (or their employers, if noted in individual files) own their
- individual contributions. The authors' contributions can be discerned
- from the git history.
-
- -------------------------------------------------------------------------
-
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
diff --git a/docker/README.md b/docker/README.md
index 6f2314e96..c14b9bf75 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -2,7 +2,7 @@
2 sets of configuration are provided - (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8, and (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8.
-If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
+If your NVIDIA driver supports CUDA Version: 11.3, please go for case (a) Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8.
Otherwise, since the older PyTorch images are not updated with the [apt-key rotation by NVIDIA](https://developer.nvidia.com/blog/updating-the-cuda-linux-gpg-repository-key), you have to go for case (b) Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8. Ensure that your NVDIA driver supports at least CUDA 11.0.
@@ -10,7 +10,7 @@ You can check the highest CUDA version within your NVIDIA driver's support with
```bash
$ nvidia-smi
-Tue Sep 20 00:26:13 2022
+Tue Sep 20 00:26:13 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.03 Driver Version: 450.119.03 CUDA Version: 11.0 |
|-------------------------------+----------------------+----------------------+
@@ -26,7 +26,7 @@ Tue Sep 20 00:26:13 2022
| 41% 30C P8 11W / 280W | 6MiB / 24220MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
-
+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
@@ -40,15 +40,15 @@ Tue Sep 20 00:26:13 2022
```
## Building images locally
-If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly.
-For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details.
+If your environment requires a proxy to access the Internet, remember to add those information into the Dockerfile directly.
+For most cases, you can uncomment these lines in the Dockerfile and add in your proxy details.
```dockerfile
ENV http_proxy=http://aaa.bb.cc.net:8080 \
https_proxy=http://aaa.bb.cc.net:8080
```
-Then, proceed with these commands.
+Then, proceed with these commands.
### If you are case (a), i.e. your NVIDIA driver supports CUDA version >= 11.3:
@@ -72,11 +72,11 @@ docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all icefall
```
### Tips:
-1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
+1. Since your data and models most probably won't be in the docker, you must use the -v flag to access the host machine. Do this by specifying `-v {/path/in/host/machine}:{/path/in/docker}`.
2. Also, if your environment requires a proxy, this would be a good time to add it in too: `-e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080`.
-Overall, your docker run command should look like this.
+Overall, your docker run command should look like this.
```bash
docker run -it --runtime=nvidia --shm-size=2gb --name=icefall --gpus all -v {/path/in/host/machine}:{/path/in/docker} -e http_proxy=http://aaa.bb.cc.net:8080 -e https_proxy=http://aaa.bb.cc.net:8080 icefall/pytorch1.12.1
@@ -86,9 +86,9 @@ You can explore more docker run options [here](https://docs.docker.com/engine/re
### Linking to icefall in your host machine
-If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container.
+If you already have icefall downloaded onto your host machine, you can use that repository instead so that changes in your code are visible inside and outside of the container.
-Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine.
+Note: Remember to set the -v flag above during the first run of the container, as that is the only way for your container to access your host machine.
Warning: Check that the icefall in your host machine is visible from within your container before proceeding to the commands below.
Use these commands once you are inside the container.
@@ -103,7 +103,7 @@ ln -s {/path/in/docker/to/icefall} /workspace/icefall
docker exec -it icefall /bin/bash
```
-## Restarting a killed container that has been run before.
+## Restarting a killed container that has been run before.
```bash
docker start -ai icefall
```
@@ -111,4 +111,4 @@ docker start -ai icefall
## Sample usage of the CPU based images:
```bash
docker run -it icefall /bin/bash
-```
+```
diff --git a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
index 524303fb8..ff9e40604 100644
--- a/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
+++ b/docker/Ubuntu18.04-pytorch1.12.1-cuda11.3-cudnn8/Dockerfile
@@ -1,7 +1,7 @@
FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-devel
# ENV http_proxy=http://aaa.bbb.cc.net:8080 \
-# https_proxy=http://aaa.bbb.cc.net:8080
+# https_proxy=http://aaa.bbb.cc.net:8080
# install normal source
RUN apt-get update && \
@@ -38,10 +38,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
rm -rf cmake-3.18.0.tar.gz && \
find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
-
-# flac
+
+# flac
RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \
- cd /opt && \
+ cd /opt && \
xz -d flac-1.3.2.tar.xz && \
tar -xvf flac-1.3.2.tar && \
cd flac-1.3.2 && \
@@ -49,11 +49,11 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
make && make install && \
rm -rf flac-1.3.2.tar && \
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
- cd -
+ cd -
RUN conda install -y -c pytorch torchaudio=0.12 && \
pip install graphviz
-
+
#install k2 from source
RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
@@ -68,6 +68,7 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install -r requirements.txt
+RUN pip install kaldifeat
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
-WORKDIR /workspace/icefall
\ No newline at end of file
+WORKDIR /workspace/icefall
diff --git a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
index 17a8215f9..5c7423fa5 100644
--- a/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
+++ b/docker/Ubuntu18.04-pytorch1.7.1-cuda11.0-cudnn8/Dockerfile
@@ -1,12 +1,12 @@
FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel
# ENV http_proxy=http://aaa.bbb.cc.net:8080 \
-# https_proxy=http://aaa.bbb.cc.net:8080
+# https_proxy=http://aaa.bbb.cc.net:8080
RUN rm /etc/apt/sources.list.d/cuda.list && \
rm /etc/apt/sources.list.d/nvidia-ml.list && \
apt-key del 7fa2af80
-
+
# install normal source
RUN apt-get update && \
apt-get install -y --no-install-recommends \
@@ -36,7 +36,7 @@ RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu18
curl -fsSL https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list && \
- rm -rf /var/lib/apt/lists/* && \
+ rm -rf /var/lib/apt/lists/* && \
mv /opt/conda/lib/libcufft.so.10 /opt/libcufft.so.10.bak && \
mv /opt/conda/lib/libcurand.so.10 /opt/libcurand.so.10.bak && \
mv /opt/conda/lib/libcublas.so.11 /opt/libcublas.so.11.bak && \
@@ -56,10 +56,10 @@ RUN wget -P /opt https://cmake.org/files/v3.18/cmake-3.18.0.tar.gz && \
rm -rf cmake-3.18.0.tar.gz && \
find /opt/cmake-3.18.0 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
cd -
-
-# flac
+
+# flac
RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz && \
- cd /opt && \
+ cd /opt && \
xz -d flac-1.3.2.tar.xz && \
tar -xvf flac-1.3.2.tar && \
cd flac-1.3.2 && \
@@ -67,7 +67,7 @@ RUN wget -P /opt https://downloads.xiph.org/releases/flac/flac-1.3.2.tar.xz &&
make && make install && \
rm -rf flac-1.3.2.tar && \
find /opt/flac-1.3.2 -type f \( -name "*.o" -o -name "*.la" -o -name "*.a" \) -exec rm {} \; && \
- cd -
+ cd -
RUN conda install -y -c pytorch torchaudio=0.7.1 && \
pip install graphviz
@@ -79,7 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
cd -
# install lhotse
-RUN pip install git+https://github.com/lhotse-speech/lhotse
+RUN pip install git+https://github.com/lhotse-speech/lhotse
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
@@ -88,4 +88,3 @@ RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall
-
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 000000000..3abb38f8b
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,24 @@
+
+## Usage
+
+```bash
+cd /path/to/icefall/docs
+pip install -r requirements.txt
+make clean
+make html
+cd build/html
+python3 -m http.server 8000
+```
+
+It prints:
+
+```
+Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
+```
+
+Open your browser and go to <http://localhost:8000/> to view the generated
+documentation.
+
+Done!
+
+**Hint**: You can change the port number when starting the server.
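For example, to serve the built documentation on a different port (8080 here is an arbitrary choice):

```bash
cd build/html
python3 -m http.server 8080
```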
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 221d9d734..6901dec02 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -78,3 +78,15 @@ html_context = {
}
todo_include_todos = True
+
+rst_epilog = """
+.. _sherpa-ncnn: https://github.com/k2-fsa/sherpa-ncnn
+.. _sherpa-onnx: https://github.com/k2-fsa/sherpa-onnx
+.. _icefall: https://github.com/k2-fsa/icefall
+.. _git-lfs: https://git-lfs.com/
+.. _ncnn: https://github.com/tencent/ncnn
+.. _LibriSpeech: https://www.openslr.org/12
+.. _musan: http://www.openslr.org/17/
+.. _ONNX: https://github.com/onnx/onnx
+.. _onnxruntime: https://github.com/microsoft/onnxruntime
+"""
diff --git a/docs/source/contributing/code-style.rst b/docs/source/contributing/code-style.rst
index 7d61a3ba1..3baaaeec2 100644
--- a/docs/source/contributing/code-style.rst
+++ b/docs/source/contributing/code-style.rst
@@ -11,9 +11,9 @@ We use the following tools to make the code style to be as consistent as possibl
The following versions of the above tools are used:
- - ``black == 12.6b0``
- - ``flake8 == 3.9.2``
- - ``isort == 5.9.2``
+ - ``black == 22.3.0``
+ - ``flake8 == 5.0.4``
+ - ``isort == 5.10.1``
After running the following commands:
@@ -54,10 +54,17 @@ it should succeed this time:
If you want to check the style of your code before ``git commit``, you
can do the following:
+ .. code-block:: bash
+
+ $ pre-commit install
+ $ pre-commit run
+
+Or without installing the pre-commit hooks:
+
.. code-block:: bash
$ cd icefall
- $ pip install black==21.6b0 flake8==3.9.2 isort==5.9.2
+ $ pip install black==22.3.0 flake8==5.0.4 isort==5.10.1
$ black --check your_changed_file.py
$ black your_changed_file.py # modify it in-place
$
diff --git a/docs/source/faqs.rst b/docs/source/faqs.rst
new file mode 100644
index 000000000..72b0302d7
--- /dev/null
+++ b/docs/source/faqs.rst
@@ -0,0 +1,107 @@
+Frequently Asked Questions (FAQs)
+=================================
+
+In this section, we collect issues reported by users and post the corresponding
+solutions.
+
+
+OSError: libtorch_hip.so: cannot open shared object file: no such file or directory
+-----------------------------------------------------------------------------------
+
+One user is using the following code to install ``torch`` and ``torchaudio``:
+
+.. code-block:: bash
+
+ pip install \
+ torch==1.10.0+cu111 \
+ torchvision==0.11.0+cu111 \
+ torchaudio==0.10.0 \
+ -f https://download.pytorch.org/whl/torch_stable.html
+
+and it throws the following error when running ``tdnn/train.py``:
+
+.. code-block::
+
+ OSError: libtorch_hip.so: cannot open shared object file: no such file or directory
+
+The fix is to specify the CUDA version while installing ``torchaudio``. That
+is, change ``torchaudio==0.10.0`` to ``torchaudio==0.10.0+cu111``. Therefore,
+the correct command is:
+
+.. code-block:: bash
+
+ pip install \
+ torch==1.10.0+cu111 \
+ torchvision==0.11.0+cu111 \
+ torchaudio==0.10.0+cu111 \
+ -f https://download.pytorch.org/whl/torch_stable.html
+
+AttributeError: module 'distutils' has no attribute 'version'
+-------------------------------------------------------------
+
+The error log is:
+
+.. code-block::
+
+ Traceback (most recent call last):
+ File "./tdnn/train.py", line 14, in
+ from asr_datamodule import YesNoAsrDataModule
+ File "/home/xxx/code/next-gen-kaldi/icefall/egs/yesno/ASR/tdnn/asr_datamodule.py", line 34, in
+ from icefall.dataset.datamodule import DataModule
+ File "/home/xxx/code/next-gen-kaldi/icefall/icefall/__init__.py", line 3, in
+ from . import (
+ File "/home/xxx/code/next-gen-kaldi/icefall/icefall/decode.py", line 23, in
+ from icefall.utils import add_eos, add_sos, get_texts
+ File "/home/xxx/code/next-gen-kaldi/icefall/icefall/utils.py", line 39, in
+ from torch.utils.tensorboard import SummaryWriter
+ File "/home/xxx/tool/miniconda3/envs/yyy/lib/python3.8/site-packages/torch/utils/tensorboard/__init__.py", line 4, in
+ LooseVersion = distutils.version.LooseVersion
+ AttributeError: module 'distutils' has no attribute 'version'
+
+The fix is:
+
+.. code-block:: bash
+
+ pip uninstall setuptools
+
+ pip install setuptools==58.0.4
+
+ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory
+--------------------------------------------------------------------------------------------
+
+If you are using ``conda`` and encounter the following issue:
+
+.. code-block::
+
+ Traceback (most recent call last):
+ File "/k2-dev/yangyifan/anaconda3/envs/icefall/lib/python3.10/site-packages/k2-1.23.3.dev20230112+cuda11.6.torch1.13.1-py3.10-linux-x86_64.egg/k2/__init__.py", line 24, in
+ from _k2 import DeterminizeWeightPushingType
+ ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory
+
+ During handling of the above exception, another exception occurred:
+
+ Traceback (most recent call last):
+ File "/k2-dev/yangyifan/icefall/egs/librispeech/ASR/./pruned_transducer_stateless7_ctc_bs/decode.py", line 104, in
+ import k2
+ File "/k2-dev/yangyifan/anaconda3/envs/icefall/lib/python3.10/site-packages/k2-1.23.3.dev20230112+cuda11.6.torch1.13.1-py3.10-linux-x86_64.egg/k2/__init__.py", line 30, in
+ raise ImportError(
+ ImportError: libpython3.10.so.1.0: cannot open shared object file: No such file or directory
+ Note: If you're using anaconda and importing k2 on MacOS,
+ you can probably fix this by setting the environment variable:
+ export DYLD_LIBRARY_PATH=$CONDA_PREFIX/lib/python3.10/site-packages:$DYLD_LIBRARY_PATH
+
+Please first try to find where ``libpython3.10.so.1.0`` is located.
+
+For instance,
+
+.. code-block:: bash
+
+ cd $CONDA_PREFIX/lib
+ find . -name "libpython*"
+
+If you are able to find it inside ``$CONDA_PREFIX/lib``, please set the
+following environment variable:
+
+.. code-block:: bash
+
+ export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
diff --git a/docs/source/index.rst b/docs/source/index.rst
index be9977ca9..8d76eb68b 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -21,7 +21,16 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
:caption: Contents:
installation/index
+ faqs
model-export/index
+
+.. toctree::
+ :maxdepth: 3
+
recipes/index
+
+.. toctree::
+ :maxdepth: 2
+
contributing/index
huggingface/index
diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
index 534b2e534..3019ff03d 100644
--- a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
+++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg
index 4254dc58a..df677ad09 100644
--- a/docs/source/installation/images/python-gt-v3.6-blue.svg
+++ b/docs/source/installation/images/python-gt-v3.6-blue.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
index d3ece9a17..d7007d742 100644
--- a/docs/source/installation/images/torch-gt-v1.6.0-green.svg
+++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst
index c4474c3d9..5b9fb2664 100644
--- a/docs/source/installation/index.rst
+++ b/docs/source/installation/index.rst
@@ -393,6 +393,17 @@ Now let us run the training part:
We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
even if there are GPUs available.
+.. hint::
+
+ In case you get a ``Segmentation fault (core dump)`` error, please use:
+
+ .. code-block:: bash
+
+ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+ See more at `` if you are
+ interested.
+
The training log is given below:
.. code-block::
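Putting the new hint together with the CPU-only setting described above, the training command might look like this sketch (assuming the working directory is ``egs/yesno/ASR``, as in this walkthrough):

```bash
# Sketch: CPU-only yesno training with the protobuf workaround applied.
export CUDA_VISIBLE_DEVICES=""
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
./tdnn/train.py
```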
diff --git a/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt
new file mode 100644
index 000000000..ecbdd4b31
--- /dev/null
+++ b/docs/source/model-export/code/export-conv-emformer-transducer-for-ncnn-output.txt
@@ -0,0 +1,21 @@
+2023-01-11 12:15:38,677 INFO [export-for-ncnn.py:220] device: cpu
+2023-01-11 12:15:38,681 INFO [export-for-ncnn.py:229] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_v
+alid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampl
+ing_factor': 4, 'decoder_dim': 512, 'joiner_dim': 512, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.23.2', 'k2-build-type':
+'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'a34171ed85605b0926eebbd0463d059431f4f74a', 'k2-git-date': 'Wed Dec 14 00:06:38 2022',
+ 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-vers
+ion': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'fix-stateless3-train-2022-12-27', 'icefall-git-sha1': '530e8a1-dirty', '
+icefall-git-date': 'Tue Dec 27 13:59:18 2022', 'icefall-path': '/star-fj/fangjun/open-source/icefall', 'k2-path': '/star-fj/fangjun/op
+en-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279
+-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '127.0.0.1'}, 'epoch': 30, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefa
+ll-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp'), 'bpe_model': './icefall-asr-librispeech-conv-emformer-transdu
+cer-stateless2-2022-07-05//data/lang_bpe_500/bpe.model', 'jit': False, 'context_size': 2, 'use_averaged_model': False, 'encoder_dim':
+512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'cnn_module_kernel': 31, 'left_context_length': 32, 'chunk_length'
+: 32, 'right_context_length': 8, 'memory_size': 32, 'blank_id': 0, 'vocab_size': 500}
+2023-01-11 12:15:38,681 INFO [export-for-ncnn.py:231] About to create model
+2023-01-11 12:15:40,053 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-conv-emformer-transducer-stateless2-2
+022-07-05/exp/epoch-30.pt
+2023-01-11 12:15:40,708 INFO [export-for-ncnn.py:315] Number of model parameters: 75490012
+2023-01-11 12:15:41,681 INFO [export-for-ncnn.py:318] Using torch.jit.trace()
+2023-01-11 12:15:41,681 INFO [export-for-ncnn.py:320] Exporting encoder
+2023-01-11 12:15:41,682 INFO [export-for-ncnn.py:149] chunk_length: 32, right_context_length: 8
diff --git a/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt
new file mode 100644
index 000000000..fe4460985
--- /dev/null
+++ b/docs/source/model-export/code/export-lstm-transducer-for-ncnn-output.txt
@@ -0,0 +1,18 @@
+2023-02-17 11:22:42,862 INFO [export-for-ncnn.py:222] device: cpu
+2023-02-17 11:22:42,865 INFO [export-for-ncnn.py:231] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'dim_feedforward': 2048, 'decoder_dim': 512, 'joiner_dim': 512, 'is_pnnx': False, 'model_warm_step': 3000, 'env_info': {'k2-version': '1.23.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '62e404dd3f3a811d73e424199b3408e309c06e1a', 'k2-git-date': 'Mon Jan 30 10:26:16 2023', 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': False, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '6d7a559-dirty', 'icefall-git-date': 'Thu Feb 16 19:47:54 2023', 'icefall-path': '/star-fj/fangjun/open-source/icefall-2', 'k2-path': '/star-fj/fangjun/open-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '10.177.6.147'}, 'epoch': 99, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp'), 'bpe_model': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/bpe.model', 'context_size': 2, 'use_averaged_model': False, 'num_encoder_layers': 12, 'encoder_dim': 512, 'rnn_hidden_size': 1024, 'aux_layer_period': 0, 'blank_id': 0, 'vocab_size': 500}
+2023-02-17 11:22:42,865 INFO [export-for-ncnn.py:235] About to create model
+2023-02-17 11:22:43,239 INFO [train.py:472] Disable giga
+2023-02-17 11:22:43,249 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/epoch-99.pt
+2023-02-17 11:22:44,595 INFO [export-for-ncnn.py:324] encoder parameters: 83137520
+2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:325] decoder parameters: 257024
+2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:326] joiner parameters: 781812
+2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:327] total parameters: 84176356
+2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:329] Using torch.jit.trace()
+2023-02-17 11:22:44,596 INFO [export-for-ncnn.py:331] Exporting encoder
+2023-02-17 11:22:48,182 INFO [export-for-ncnn.py:158] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.pt
+2023-02-17 11:22:48,183 INFO [export-for-ncnn.py:335] Exporting decoder
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/lstm_transducer_stateless2/decoder.py:101: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ need_pad = bool(need_pad)
+2023-02-17 11:22:48,259 INFO [export-for-ncnn.py:180] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.pt
+2023-02-17 11:22:48,259 INFO [export-for-ncnn.py:339] Exporting joiner
+2023-02-17 11:22:48,304 INFO [export-for-ncnn.py:207] Saved to icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.pt
diff --git a/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt
new file mode 100644
index 000000000..25874a414
--- /dev/null
+++ b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt
@@ -0,0 +1,74 @@
+2023-02-27 20:23:07,473 INFO [export-for-ncnn.py:246] device: cpu
+2023-02-27 20:23:07,477 INFO [export-for-ncnn.py:255] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'warm_step': 2000, 'env_info': {'k2-version': '1.23.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '62e404dd3f3a811d73e424199b3408e309c06e1a', 'k2-git-date': 'Mon Jan 30 10:26:16 2023', 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '6d7a559-clean', 'icefall-git-date': 'Thu Feb 16 19:47:54 2023', 'icefall-path': '/star-fj/fangjun/open-source/icefall-2', 'k2-path': '/star-fj/fangjun/open-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '10.177.6.147'}, 'epoch': 99, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp'), 'bpe_model': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model', 'context_size': 2, 'use_averaged_model': False, 'num_encoder_layers': '2,4,3,2,4', 'feedforward_dims': '1024,1024,2048,2048,1024', 'nhead': '8,8,8,8,8', 'encoder_dims': '384,384,384,384,384', 'attention_dims': '192,192,192,192,192', 'encoder_unmasked_dims': '256,256,256,256,256', 'zipformer_downsampling_factors': '1,2,4,8,2', 'cnn_module_kernels': '31,31,31,31,31', 'decoder_dim': 512, 'joiner_dim': 512, 'short_chunk_size': 50, 'num_left_chunks': 4, 'decode_chunk_len': 32, 'blank_id': 0, 'vocab_size': 500}
+2023-02-27 20:23:07,477 INFO [export-for-ncnn.py:257] About to create model
+2023-02-27 20:23:08,023 INFO [zipformer2.py:419] At encoder stack 4, which has downsampling_factor=2, we will combine the outputs of layers 1 and 3, with downsampling_factors=2 and 8.
+2023-02-27 20:23:08,037 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/epoch-99.pt
+2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:346] encoder parameters: 68944004
+2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:347] decoder parameters: 260096
+2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:348] joiner parameters: 716276
+2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:349] total parameters: 69920376
+2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:351] Using torch.jit.trace()
+2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:353] Exporting encoder
+2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:174] decode_chunk_len: 32
+2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:175] T: 39
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1344: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_len.size(0) == self.num_layers, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1348: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_avg.size(0) == self.num_layers, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1352: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_key.size(0) == self.num_layers, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1356: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_val.size(0) == self.num_layers, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1360: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_val2.size(0) == self.num_layers, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1364: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_conv1.size(0) == self.num_layers, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1368: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_conv2.size(0) == self.num_layers, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1373: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert self.left_context_len == cached_key.shape[1], (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1884: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert self.x_size == x.size(0), (self.x_size, x.size(0))
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2442: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_key.shape[0] == self.left_context_len, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2449: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_key.shape[0] == cached_val.shape[0], (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2469: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_key.shape[0] == left_context_len, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2473: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_val.shape[0] == left_context_len, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2483: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert kv_len == k.shape[0], (kv_len, k.shape)
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2570: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2926: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cache.shape == (x.size(0), x.size(1), self.lorder), (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2652: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert x.shape[0] == self.x_size, (x.shape[0], self.x_size)
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2653: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert x.shape[2] == self.embed_dim, (x.shape[2], self.embed_dim)
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2666: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert cached_val.shape[0] == self.left_context_len, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1543: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert src.shape[0] == self.in_x_size, (src.shape[0], self.in_x_size)
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1637: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert src.shape[0] == self.in_x_size, (
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1643: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert src.shape[2] == self.in_channels, (src.shape[2], self.in_channels)
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1571: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ if src.shape[0] != self.in_x_size:
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1763: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape)
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1779: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert src1.shape[-1] == self.dim1, (src1.shape[-1], self.dim1)
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1780: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert src2.shape[-1] == self.dim2, (src2.shape[-1], self.dim2)
+/star-fj/fangjun/py38/lib/python3.8/site-packages/torch/jit/_trace.py:958: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
+ module._c._create_method_from_trace(
+2023-02-27 20:23:19,640 INFO [export-for-ncnn.py:182] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.pt
+2023-02-27 20:23:19,646 INFO [export-for-ncnn.py:357] Exporting decoder
+/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py:102: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
+ assert embedding_out.size(-1) == self.context_size
+2023-02-27 20:23:19,686 INFO [export-for-ncnn.py:204] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.pt
+2023-02-27 20:23:19,686 INFO [export-for-ncnn.py:361] Exporting joiner
+2023-02-27 20:23:19,735 INFO [export-for-ncnn.py:231] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.pt
diff --git a/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt b/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt
new file mode 100644
index 000000000..347e7e51a
--- /dev/null
+++ b/docs/source/model-export/code/generate-int-8-scale-table-for-conv-emformer.txt
@@ -0,0 +1,104 @@
+Don't Use GPU. has_gpu: 0, config.use_vulkan_compute: 1
+num encoder conv layers: 88
+num joiner conv layers: 3
+num files: 3
+Processing ../test_wavs/1089-134686-0001.wav
+Processing ../test_wavs/1221-135766-0001.wav
+Processing ../test_wavs/1221-135766-0002.wav
+Processing ../test_wavs/1089-134686-0001.wav
+Processing ../test_wavs/1221-135766-0001.wav
+Processing ../test_wavs/1221-135766-0002.wav
+----------encoder----------
+conv_87 : max = 15.942385 threshold = 15.938493 scale = 7.968131
+conv_88 : max = 35.442448 threshold = 15.549335 scale = 8.167552
+conv_89 : max = 23.228289 threshold = 8.001738 scale = 15.871552
+linear_90 : max = 3.976146 threshold = 1.101789 scale = 115.267128
+linear_91 : max = 6.962030 threshold = 5.162033 scale = 24.602713
+linear_92 : max = 12.323041 threshold = 3.853959 scale = 32.953129
+linear_94 : max = 6.905416 threshold = 4.648006 scale = 27.323545
+linear_93 : max = 6.905416 threshold = 5.474093 scale = 23.200188
+linear_95 : max = 1.888012 threshold = 1.403563 scale = 90.483986
+linear_96 : max = 6.856741 threshold = 5.398679 scale = 23.524273
+linear_97 : max = 9.635942 threshold = 2.613655 scale = 48.590950
+linear_98 : max = 6.460340 threshold = 5.670146 scale = 22.398010
+linear_99 : max = 9.532276 threshold = 2.585537 scale = 49.119396
+linear_101 : max = 6.585871 threshold = 5.719224 scale = 22.205809
+linear_100 : max = 6.585871 threshold = 5.751382 scale = 22.081648
+linear_102 : max = 1.593344 threshold = 1.450581 scale = 87.551147
+linear_103 : max = 6.592681 threshold = 5.705824 scale = 22.257959
+linear_104 : max = 8.752957 threshold = 1.980955 scale = 64.110489
+linear_105 : max = 6.696240 threshold = 5.877193 scale = 21.608953
+linear_106 : max = 9.059659 threshold = 2.643138 scale = 48.048950
+linear_108 : max = 6.975461 threshold = 4.589567 scale = 27.671457
+linear_107 : max = 6.975461 threshold = 6.190381 scale = 20.515701
+linear_109 : max = 3.710759 threshold = 2.305635 scale = 55.082436
+linear_110 : max = 7.531228 threshold = 5.731162 scale = 22.159557
+linear_111 : max = 10.528083 threshold = 2.259322 scale = 56.211544
+linear_112 : max = 8.148807 threshold = 5.500842 scale = 23.087374
+linear_113 : max = 8.592566 threshold = 1.948851 scale = 65.166611
+linear_115 : max = 8.437109 threshold = 5.608947 scale = 22.642395
+linear_114 : max = 8.437109 threshold = 6.193942 scale = 20.503904
+linear_116 : max = 3.966980 threshold = 3.200896 scale = 39.676392
+linear_117 : max = 9.451303 threshold = 6.061664 scale = 20.951344
+linear_118 : max = 12.077262 threshold = 3.965800 scale = 32.023804
+linear_119 : max = 9.671615 threshold = 4.847613 scale = 26.198460
+linear_120 : max = 8.625638 threshold = 3.131427 scale = 40.556595
+linear_122 : max = 10.274080 threshold = 4.888716 scale = 25.978189
+linear_121 : max = 10.274080 threshold = 5.420480 scale = 23.429659
+linear_123 : max = 4.826197 threshold = 3.599617 scale = 35.281532
+linear_124 : max = 11.396383 threshold = 7.325849 scale = 17.335875
+linear_125 : max = 9.337198 threshold = 3.941410 scale = 32.221970
+linear_126 : max = 9.699965 threshold = 4.842878 scale = 26.224073
+linear_127 : max = 8.775370 threshold = 3.884215 scale = 32.696438
+linear_129 : max = 9.872276 threshold = 4.837319 scale = 26.254213
+linear_128 : max = 9.872276 threshold = 7.180057 scale = 17.687883
+linear_130 : max = 4.150427 threshold = 3.454298 scale = 36.765789
+linear_131 : max = 11.112692 threshold = 7.924847 scale = 16.025545
+linear_132 : max = 11.852893 threshold = 3.116593 scale = 40.749626
+linear_133 : max = 11.517084 threshold = 5.024665 scale = 25.275314
+linear_134 : max = 10.683807 threshold = 3.878618 scale = 32.743618
+linear_136 : max = 12.421055 threshold = 6.322729 scale = 20.086264
+linear_135 : max = 12.421055 threshold = 5.309880 scale = 23.917679
+linear_137 : max = 4.827781 threshold = 3.744595 scale = 33.915554
+linear_138 : max = 14.422395 threshold = 7.742882 scale = 16.402161
+linear_139 : max = 8.527538 threshold = 3.866123 scale = 32.849449
+linear_140 : max = 12.128619 threshold = 4.657793 scale = 27.266134
+linear_141 : max = 9.839593 threshold = 3.845993 scale = 33.021378
+linear_143 : max = 12.442304 threshold = 7.099039 scale = 17.889746
+linear_142 : max = 12.442304 threshold = 5.325038 scale = 23.849592
+linear_144 : max = 5.929444 threshold = 5.618206 scale = 22.605080
+linear_145 : max = 13.382126 threshold = 9.321095 scale = 13.625010
+linear_146 : max = 9.894987 threshold = 3.867645 scale = 32.836517
+linear_147 : max = 10.915313 threshold = 4.906028 scale = 25.886522
+linear_148 : max = 9.614287 threshold = 3.908151 scale = 32.496181
+linear_150 : max = 11.724932 threshold = 4.485588 scale = 28.312899
+linear_149 : max = 11.724932 threshold = 5.161146 scale = 24.606939
+linear_151 : max = 7.164453 threshold = 5.847355 scale = 21.719223
+linear_152 : max = 13.086471 threshold = 5.984121 scale = 21.222834
+linear_153 : max = 11.099524 threshold = 3.991601 scale = 31.816805
+linear_154 : max = 10.054585 threshold = 4.489706 scale = 28.286930
+linear_155 : max = 12.389185 threshold = 3.100321 scale = 40.963501
+linear_157 : max = 9.982999 threshold = 5.154796 scale = 24.637253
+linear_156 : max = 9.982999 threshold = 8.537706 scale = 14.875190
+linear_158 : max = 8.420287 threshold = 6.502287 scale = 19.531588
+linear_159 : max = 25.014746 threshold = 9.423280 scale = 13.477261
+linear_160 : max = 45.633553 threshold = 5.715335 scale = 22.220921
+linear_161 : max = 20.371849 threshold = 5.117830 scale = 24.815203
+linear_162 : max = 12.492933 threshold = 3.126283 scale = 40.623318
+linear_164 : max = 20.697504 threshold = 4.825712 scale = 26.317358
+linear_163 : max = 20.697504 threshold = 5.078367 scale = 25.008038
+linear_165 : max = 9.023975 threshold = 6.836278 scale = 18.577358
+linear_166 : max = 34.860619 threshold = 7.259792 scale = 17.493614
+linear_167 : max = 30.380934 threshold = 5.496160 scale = 23.107042
+linear_168 : max = 20.691216 threshold = 4.733317 scale = 26.831076
+linear_169 : max = 9.723948 threshold = 3.952728 scale = 32.129707
+linear_171 : max = 21.034811 threshold = 5.366547 scale = 23.665123
+linear_170 : max = 21.034811 threshold = 5.356277 scale = 23.710501
+linear_172 : max = 10.556884 threshold = 5.729481 scale = 22.166058
+linear_173 : max = 20.033039 threshold = 10.207264 scale = 12.442120
+linear_174 : max = 11.597379 threshold = 2.658676 scale = 47.768131
+----------joiner----------
+linear_2 : max = 19.293503 threshold = 14.305265 scale = 8.877850
+linear_1 : max = 10.812222 threshold = 8.766452 scale = 14.487047
+linear_3 : max = 0.999999 threshold = 0.999755 scale = 127.031174
+ncnn int8 calibration table create success, best wish for your int8 inference has a low accuracy loss...\(^0^)/...233...
diff --git a/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt b/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt
new file mode 100644
index 000000000..d39215b14
--- /dev/null
+++ b/docs/source/model-export/code/generate-int-8-scale-table-for-lstm.txt
@@ -0,0 +1,44 @@
+Don't Use GPU. has_gpu: 0, config.use_vulkan_compute: 1
+num encoder conv layers: 28
+num joiner conv layers: 3
+num files: 3
+Processing ../test_wavs/1089-134686-0001.wav
+Processing ../test_wavs/1221-135766-0001.wav
+Processing ../test_wavs/1221-135766-0002.wav
+Processing ../test_wavs/1089-134686-0001.wav
+Processing ../test_wavs/1221-135766-0001.wav
+Processing ../test_wavs/1221-135766-0002.wav
+----------encoder----------
+conv_15 : max = 15.942385 threshold = 15.930708 scale = 7.972025
+conv_16 : max = 44.978855 threshold = 17.031788 scale = 7.456645
+conv_17 : max = 17.868437 threshold = 7.830528 scale = 16.218575
+linear_18 : max = 3.107259 threshold = 1.194808 scale = 106.293236
+linear_19 : max = 6.193777 threshold = 4.634748 scale = 27.401705
+linear_20 : max = 9.259933 threshold = 2.606617 scale = 48.722160
+linear_21 : max = 5.186600 threshold = 4.790260 scale = 26.512129
+linear_22 : max = 9.759041 threshold = 2.265832 scale = 56.050053
+linear_23 : max = 3.931209 threshold = 3.099090 scale = 40.979767
+linear_24 : max = 10.324160 threshold = 2.215561 scale = 57.321835
+linear_25 : max = 3.800708 threshold = 3.599352 scale = 35.284134
+linear_26 : max = 10.492444 threshold = 3.153369 scale = 40.274391
+linear_27 : max = 3.660161 threshold = 2.720994 scale = 46.674126
+linear_28 : max = 9.415265 threshold = 3.174434 scale = 40.007133
+linear_29 : max = 4.038418 threshold = 3.118534 scale = 40.724262
+linear_30 : max = 10.072084 threshold = 3.936867 scale = 32.259155
+linear_31 : max = 4.342712 threshold = 3.599489 scale = 35.282787
+linear_32 : max = 11.340535 threshold = 3.120308 scale = 40.701103
+linear_33 : max = 3.846987 threshold = 3.630030 scale = 34.985939
+linear_34 : max = 10.686298 threshold = 2.204571 scale = 57.607586
+linear_35 : max = 4.904821 threshold = 4.575518 scale = 27.756420
+linear_36 : max = 11.806659 threshold = 2.585589 scale = 49.118401
+linear_37 : max = 6.402340 threshold = 5.047157 scale = 25.162680
+linear_38 : max = 11.174589 threshold = 1.923361 scale = 66.030258
+linear_39 : max = 16.178576 threshold = 7.556058 scale = 16.807705
+linear_40 : max = 12.901954 threshold = 5.301267 scale = 23.956539
+linear_41 : max = 14.839805 threshold = 7.597429 scale = 16.716181
+linear_42 : max = 10.178945 threshold = 2.651595 scale = 47.895699
+----------joiner----------
+linear_2 : max = 24.829245 threshold = 16.627592 scale = 7.637907
+linear_1 : max = 10.746186 threshold = 5.255032 scale = 24.167313
+linear_3 : max = 1.000000 threshold = 0.999756 scale = 127.031013
+ncnn int8 calibration table create success, best wish for your int8 inference has a low accuracy loss...\(^0^)/...233...
diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt
new file mode 100644
index 000000000..114fe7342
--- /dev/null
+++ b/docs/source/model-export/code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt
@@ -0,0 +1,7 @@
+2023-01-11 14:02:12,216 INFO [streaming-ncnn-decode.py:320] {'tokens': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav'}
+T 51 32
+2023-01-11 14:02:13,141 INFO [streaming-ncnn-decode.py:328] Constructing Fbank computer
+2023-01-11 14:02:13,151 INFO [streaming-ncnn-decode.py:331] Reading sound files: ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav
+2023-01-11 14:02:13,176 INFO [streaming-ncnn-decode.py:336] torch.Size([106000])
+2023-01-11 14:02:17,581 INFO [streaming-ncnn-decode.py:380] ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav
+2023-01-11 14:02:17,581 INFO [streaming-ncnn-decode.py:381] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt
new file mode 100644
index 000000000..3606eae3d
--- /dev/null
+++ b/docs/source/model-export/code/test-streaming-ncnn-decode-lstm-transducer-libri.txt
@@ -0,0 +1,6 @@
+2023-02-17 11:37:30,861 INFO [streaming-ncnn-decode.py:255] {'tokens': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav'}
+2023-02-17 11:37:31,425 INFO [streaming-ncnn-decode.py:263] Constructing Fbank computer
+2023-02-17 11:37:31,427 INFO [streaming-ncnn-decode.py:266] Reading sound files: ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav
+2023-02-17 11:37:31,431 INFO [streaming-ncnn-decode.py:271] torch.Size([106000])
+2023-02-17 11:37:34,115 INFO [streaming-ncnn-decode.py:342] ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav
+2023-02-17 11:37:34,115 INFO [streaming-ncnn-decode.py:343] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt
new file mode 100644
index 000000000..5b4969e0f
--- /dev/null
+++ b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt
@@ -0,0 +1,7 @@
+2023-02-27 20:43:40,283 INFO [streaming-ncnn-decode.py:349] {'tokens': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav'}
+2023-02-27 20:43:41,260 INFO [streaming-ncnn-decode.py:357] Constructing Fbank computer
+2023-02-27 20:43:41,264 INFO [streaming-ncnn-decode.py:360] Reading sound files: ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav
+2023-02-27 20:43:41,269 INFO [streaming-ncnn-decode.py:365] torch.Size([106000])
+2023-02-27 20:43:41,280 INFO [streaming-ncnn-decode.py:372] number of states: 35
+2023-02-27 20:43:45,026 INFO [streaming-ncnn-decode.py:410] ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav
+2023-02-27 20:43:45,026 INFO [streaming-ncnn-decode.py:411] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst
new file mode 100644
index 000000000..12b370143
--- /dev/null
+++ b/docs/source/model-export/export-ncnn-conv-emformer.rst
@@ -0,0 +1,753 @@
+.. _export_conv_emformer_transducer_models_to_ncnn:
+
+Export ConvEmformer transducer models to ncnn
+=============================================
+
+We use the pre-trained model from the following repository as an example:
+
+ - ``_
+
+We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_.
+
+.. hint::
+
+ We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing.
+
+.. caution::
+
+ Please use a more recent version of PyTorch. For instance, ``torch 1.8``
+ may ``not`` work.
+
+1. Download the pre-trained model
+---------------------------------
+
+.. hint::
+
+ You can also refer to ``_ to download the pre-trained model.
+
+ You have to install `git-lfs`_ before you continue.
+
+.. code-block:: bash
+
+ cd egs/librispeech/ASR
+
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
+ cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
+
+ git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+
+ cd ..
+
+.. note::
+
+ We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``.
+
+
+In the above code, we downloaded the pre-trained model into the directory
+``egs/librispeech/ASR/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05``.
+
+.. _export_for_ncnn_install_ncnn_and_pnnx:
+
+2. Install ncnn and pnnx
+------------------------
+
+.. code-block:: bash
+
+ # We put ncnn into $HOME/open-source/ncnn
+ # You can change it to anywhere you like
+
+ cd $HOME
+ mkdir -p open-source
+ cd open-source
+
+ git clone https://github.com/csukuangfj/ncnn
+ cd ncnn
+ git submodule update --recursive --init
+
+ # Note: We don't use "python setup.py install" or "pip install ." here
+
+ mkdir -p build-wheel
+ cd build-wheel
+
+ cmake \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DNCNN_PYTHON=ON \
+ -DNCNN_BUILD_BENCHMARK=OFF \
+ -DNCNN_BUILD_EXAMPLES=OFF \
+ -DNCNN_BUILD_TOOLS=ON \
+ ..
+
+ make -j4
+
+ cd ..
+
+ # Note: $PWD here is $HOME/open-source/ncnn
+
+ export PYTHONPATH=$PWD/python:$PYTHONPATH
+ export PATH=$PWD/tools/pnnx/build/src:$PATH
+ export PATH=$PWD/build-wheel/tools/quantize:$PATH
+
+ # Now build pnnx
+ cd tools/pnnx
+ mkdir build
+ cd build
+ cmake ..
+ make -j4
+
+ ./src/pnnx
+
+Congratulations! You have successfully installed the following components:
+
+ - ``pnnx``, which is an executable located in
+ ``$HOME/open-source/ncnn/tools/pnnx/build/src``. We will use
+ it to convert models exported by ``torch.jit.trace()``.
+ - ``ncnn2int8``, which is an executable located in
+ ``$HOME/open-source/ncnn/build-wheel/tools/quantize``. We will use
+ it to quantize our models to ``int8``.
+ - ``ncnn.cpython-38-x86_64-linux-gnu.so``, which is a Python module located
+ in ``$HOME/open-source/ncnn/python/ncnn``.
+
+ .. note::
+
+ I am using ``Python 3.8``, so it
+ is ``ncnn.cpython-38-x86_64-linux-gnu.so``. If you use a different
+ version, say, ``Python 3.9``, the name would be
+ ``ncnn.cpython-39-x86_64-linux-gnu.so``.
+
+     Also, if you are not using Linux, the file name will be different, but that
+     does not matter: as long as you can compile it, it should work.
+
+We have set up ``PYTHONPATH`` so that you can use ``import ncnn`` in your
+Python code. We have also set up ``PATH`` so that you can use
+``pnnx`` and ``ncnn2int8`` later in your terminal.
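+
+To quickly check that the Python module can indeed be found, you can run a tiny
+snippet like the one below (just a sanity check; it is not required for the
+remaining steps):
+
+.. code-block:: python
+
+   # should print a path under $HOME/open-source/ncnn/python/ncnn
+   import ncnn
+
+   print(ncnn.__file__)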
+
+.. caution::
+
+ Please don't use ``_.
+   We have made some modifications to the official `ncnn`_.
+
+ We will synchronize ``_ periodically
+ with the official one.
+
+3. Export the model via torch.jit.trace()
+-----------------------------------------
+
+First, let us rename our pre-trained model:
+
+.. code-block::
+
+ cd egs/librispeech/ASR
+
+ cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp
+
+ ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-30.pt
+
+ cd ../..
+
+Next, we use the following code to export our model:
+
+.. code-block:: bash
+
+ dir=./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/
+
+ ./conv_emformer_transducer_stateless2/export-for-ncnn.py \
+ --exp-dir $dir/exp \
+ --bpe-model $dir/data/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --encoder-dim 512
+
+.. caution::
+
+ If your model has different configuration parameters, please change them accordingly.
+
+.. hint::
+
+ We have renamed our model to ``epoch-30.pt`` so that we can use ``--epoch 30``.
+ There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``.
+
+   If you have trained a model yourself and have all checkpoints available,
+   please first use ``decode.py`` to tune ``--epoch --avg`` and select the
+   best combination with ``--use-averaged-model 1``.
+
+.. note::
+
+ You will see the following log output:
+
+ .. literalinclude:: ./code/export-conv-emformer-transducer-for-ncnn-output.txt
+
+ The log shows the model has ``75490012`` parameters, i.e., ``~75 M``.
+
+ .. code-block::
+
+ ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt
+
+ -rw-r--r-- 1 kuangfangjun root 289M Jan 11 12:05 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt
+
+ You can see that the file size of the pre-trained model is ``289 MB``, which
+ is roughly equal to ``75490012*4/1024/1024 = 287.97 MB``.
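+
+   As a quick sanity check, you can reproduce this estimate yourself; it is just
+   back-of-the-envelope arithmetic:
+
+   .. code-block:: python
+
+      num_params = 75490012      # from the log above
+      bytes_per_float32 = 4
+      print(num_params * bytes_per_float32 / 1024 / 1024)  # ~287.97 (MB)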
+
+After running ``conv_emformer_transducer_stateless2/export-for-ncnn.py``,
+we will get the following files:
+
+.. code-block:: bash
+
+ ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*pnnx*
+
+ -rw-r--r-- 1 kuangfangjun root 1010K Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.pt
+ -rw-r--r-- 1 kuangfangjun root 283M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.pt
+ -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 12:15 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.pt
+
+
+.. _conv-emformer-step-4-export-torchscript-model-via-pnnx:
+
+4. Export torchscript model via pnnx
+------------------------------------
+
+.. hint::
+
+   Make sure you have set up the ``PATH`` environment variable. Otherwise, the
+   following commands will fail with an error saying that ``pnnx`` cannot be found.
+
+Now, it's time to export our models to `ncnn`_ via ``pnnx``.
+
+.. code-block::
+
+ cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+ pnnx ./encoder_jit_trace-pnnx.pt
+ pnnx ./decoder_jit_trace-pnnx.pt
+ pnnx ./joiner_jit_trace-pnnx.pt
+
+It will generate the following files:
+
+.. code-block:: bash
+
+ ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*ncnn*{bin,param}
+
+ -rw-r--r-- 1 kuangfangjun root 503K Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 437 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param
+ -rw-r--r-- 1 kuangfangjun root 142M Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 79K Jan 11 12:36 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param
+ -rw-r--r-- 1 kuangfangjun root 1.5M Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 488 Jan 11 12:38 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param
+
+There are two types of files:
+
+- ``param``: It is a text file containing the model architectures. You can
+ use a text editor to view its content.
+- ``bin``: It is a binary file containing the model parameters.
+
+The table below compares the file sizes of the models before and after
+conversion via ``pnnx``:
+
+.. see https://tableconvert.com/restructuredtext-generator
+
++----------------------------------+------------+
+| File name | File size |
++==================================+============+
+| encoder_jit_trace-pnnx.pt | 283 MB |
++----------------------------------+------------+
+| decoder_jit_trace-pnnx.pt | 1010 KB |
++----------------------------------+------------+
+| joiner_jit_trace-pnnx.pt | 3.0 MB |
++----------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin | 142 MB |
++----------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin | 503 KB |
++----------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB |
++----------------------------------+------------+
+
+You can see that the file sizes of the models after conversion are about half
+of those before conversion:
+
+ - encoder: 283 MB vs 142 MB
+ - decoder: 1010 KB vs 503 KB
+ - joiner: 3.0 MB vs 1.5 MB
+
+The reason is that by default ``pnnx`` converts ``float32`` parameters
+to ``float16``. A ``float32`` parameter occupies 4 bytes, while a ``float16``
+parameter occupies only 2 bytes. Thus, the converted models are roughly half the size.
+
+.. hint::
+
+ If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx``
+ won't convert ``float32`` to ``float16``.
+
+5. Test the exported models in icefall
+--------------------------------------
+
+.. note::
+
+ We assume you have set up the environment variable ``PYTHONPATH`` when
+ building `ncnn`_.
+
+Now we have successfully converted our pre-trained model to `ncnn`_ format.
+The 6 generated files are all we need. You can use the following command to
+test the converted models:
+
+.. code-block:: bash
+
+ ./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \
+ --tokens ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/data/lang_bpe_500/tokens.txt \
+ --encoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param \
+ --encoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin \
+ --decoder-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param \
+ --decoder-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin \
+ --joiner-param-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param \
+ --joiner-bin-filename ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin \
+ ./icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/test_wavs/1089-134686-0001.wav
+
+.. hint::
+
+ `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts
+ only 1 wave file as input.
+
+The output is given below:
+
+.. literalinclude:: ./code/test-streaming-ncnn-decode-conv-emformer-transducer-libri.txt
+
+Congratulations! You have successfully exported a model from PyTorch to `ncnn`_!
+
+
+.. _conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn:
+
+6. Modify the exported encoder for sherpa-ncnn
+----------------------------------------------
+
+In order to use the exported models in `sherpa-ncnn`_, we have to modify
+``encoder_jit_trace-pnnx.ncnn.param``.
+
+Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``:
+
+.. code-block::
+
+ 7767517
+ 1060 1342
+ Input in0 0 1 in0
+
+**Explanation** of the above three lines:
+
+ 1. ``7767517``, it is a magic number and should not be changed.
+  2. ``1060 1342``, the first number ``1060`` specifies the number of layers
+     in this file, while ``1342`` specifies the number of intermediate outputs
+     of this file.
+ 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0``
+ is the layer name of this layer; ``0`` means this layer has no input;
+ ``1`` means this layer has one output; ``in0`` is the output name of
+ this layer.
+
+We need to add one extra line and also increment the number of layers.
+The result looks like the following:
+
+.. code-block:: bash
+
+ 7767517
+ 1061 1342
+ SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512
+ Input in0 0 1 in0
+
+**Explanation**
+
+  1. ``7767517``, it is still the same.
+ 2. ``1061 1342``, we have added an extra layer, so we need to update ``1060`` to ``1061``.
+ We don't need to change ``1342`` since the newly added layer has no inputs or outputs.
+ 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512``
+ This line is newly added. Its explanation is given below:
+
+ - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``.
+ - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``.
+     - ``0 0`` means this layer has no inputs or outputs. Must be ``0 0``.
+     - ``0=1``, 0 is the key and 1 is the value. Must be ``0=1``.
+ - ``1=12``, 1 is the key and 12 is the value of the
+ parameter ``--num-encoder-layers`` that you provided when running
+ ``conv_emformer_transducer_stateless2/export-for-ncnn.py``.
+ - ``2=32``, 2 is the key and 32 is the value of the
+ parameter ``--memory-size`` that you provided when running
+ ``conv_emformer_transducer_stateless2/export-for-ncnn.py``.
+ - ``3=31``, 3 is the key and 31 is the value of the
+ parameter ``--cnn-module-kernel`` that you provided when running
+ ``conv_emformer_transducer_stateless2/export-for-ncnn.py``.
+ - ``4=8``, 4 is the key and 8 is the value of the
+ parameter ``--left-context-length`` that you provided when running
+ ``conv_emformer_transducer_stateless2/export-for-ncnn.py``.
+ - ``5=32``, 5 is the key and 32 is the value of the
+ parameter ``--chunk-length`` that you provided when running
+ ``conv_emformer_transducer_stateless2/export-for-ncnn.py``.
+ - ``6=8``, 6 is the key and 8 is the value of the
+ parameter ``--right-context-length`` that you provided when running
+ ``conv_emformer_transducer_stateless2/export-for-ncnn.py``.
+ - ``7=512``, 7 is the key and 512 is the value of the
+ parameter ``--encoder-dim`` that you provided when running
+ ``conv_emformer_transducer_stateless2/export-for-ncnn.py``.
+
+ For ease of reference, we list the key-value pairs that you need to add
+ in the following table. If your model has a different setting, please
+ change the values for ``SherpaMetaData`` accordingly. Otherwise, you
+ will be ``SAD``.
+
+ +------+-----------------------------+
+ | key | value |
+ +======+=============================+
+ | 0 | 1 (fixed) |
+ +------+-----------------------------+
+ | 1 | ``--num-encoder-layers`` |
+ +------+-----------------------------+
+ | 2 | ``--memory-size`` |
+ +------+-----------------------------+
+ | 3 | ``--cnn-module-kernel`` |
+ +------+-----------------------------+
+ | 4 | ``--left-context-length`` |
+ +------+-----------------------------+
+ | 5 | ``--chunk-length`` |
+ +------+-----------------------------+
+ | 6 | ``--right-context-length`` |
+ +------+-----------------------------+
+ | 7 | ``--encoder-dim`` |
+ +------+-----------------------------+
+
+ 4. ``Input in0 0 1 in0``. No need to change it.
+
+.. caution::
+
+ When you add a new layer ``SherpaMetaData``, please remember to update the
+ number of layers. In our case, update ``1060`` to ``1061``. Otherwise,
+ you will be SAD later.
+
+.. hint::
+
+ After adding the new layer ``SherpaMetaData``, you cannot use this model
+ with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is
+ supported only in `sherpa-ncnn`_.
+
+.. hint::
+
+ `ncnn`_ is very flexible. You can add new layers to it just by text-editing
+ the ``param`` file! You don't need to change the ``bin`` file.
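+
+Since the ``param`` file is plain text, you can also script this edit. Below is
+a minimal Python sketch (a hypothetical helper, not part of icefall) that inserts
+the ``SherpaMetaData`` line and increments the layer count. Remember to adjust
+the values to match the arguments you passed to ``export-for-ncnn.py``.
+
+.. code-block:: python
+
+   param_file = "encoder_jit_trace-pnnx.ncnn.param"
+   meta = "SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512"
+
+   with open(param_file) as f:
+       lines = f.read().splitlines()
+
+   # lines[0] is the magic number 7767517; lines[1] is "<num_layers> <num_blobs>"
+   num_layers, num_blobs = map(int, lines[1].split())
+   lines[1] = f"{num_layers + 1} {num_blobs}"  # one extra layer; blob count unchanged
+   lines.insert(2, meta)  # put SherpaMetaData right before "Input in0 0 1 in0"
+
+   with open(param_file, "w") as f:
+       f.write("\n".join(lines) + "\n")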
+
+Now you can use this model in `sherpa-ncnn`_.
+Please refer to the following documentation:
+
+ - Linux/macOS/Windows/arm/aarch64: ``_
+ - ``Android``: ``_
+ - ``iOS``: ``_
+ - Python: ``_
+
+We have a list of pre-trained models that have been exported for `sherpa-ncnn`_:
+
+ - ``_
+
+ You can find more usages there.
+
+7. (Optional) int8 quantization with sherpa-ncnn
+------------------------------------------------
+
+This step is optional.
+
+In this step, we describe how to quantize our model with ``int8``.
+
+Change :ref:`conv-emformer-step-4-export-torchscript-model-via-pnnx` to
+disable ``fp16`` when using ``pnnx``:
+
+.. code-block::
+
+ cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+ pnnx ./encoder_jit_trace-pnnx.pt fp16=0
+ pnnx ./decoder_jit_trace-pnnx.pt
+ pnnx ./joiner_jit_trace-pnnx.pt fp16=0
+
+.. note::
+
+ We add ``fp16=0`` when exporting the encoder and joiner. `ncnn`_ does not
+ support quantizing the decoder model yet. We will update this documentation
+   once `ncnn`_ supports it (perhaps later in 2023).
+
+It will generate the following files:
+
+.. code-block:: bash
+
+ ls -lh icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/*_jit_trace-pnnx.ncnn.{param,bin}
+
+ -rw-r--r-- 1 kuangfangjun root 503K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 437 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/decoder_jit_trace-pnnx.ncnn.param
+ -rw-r--r-- 1 kuangfangjun root 283M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 79K Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/encoder_jit_trace-pnnx.ncnn.param
+ -rw-r--r-- 1 kuangfangjun root 3.0M Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 488 Jan 11 15:56 icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/joiner_jit_trace-pnnx.ncnn.param
+
+Let us compare the file sizes again:
+
++----------------------------------------+------------+
+| File name | File size |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.pt | 283 MB |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.pt | 1010 KB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.pt | 3.0 MB |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB |
++----------------------------------------+------------+
+
+You can see that the file sizes are doubled when we disable ``fp16``.
+
+.. note::
+
+ You can again use ``streaming-ncnn-decode.py`` to test the exported models.
+
+Next, follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn`
+to modify ``encoder_jit_trace-pnnx.ncnn.param``.
+
+Change
+
+.. code-block:: bash
+
+ 7767517
+ 1060 1342
+ Input in0 0 1 in0
+
+to
+
+.. code-block:: bash
+
+ 7767517
+ 1061 1342
+ SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 2=32 3=31 4=8 5=32 6=8 7=512
+ Input in0 0 1 in0
+
+.. caution::
+
+ Please follow :ref:`conv-emformer-modify-the-exported-encoder-for-sherpa-ncnn`
+ to change the values for ``SherpaMetaData`` if your model uses a different setting.
+
+
+Next, let us compile `sherpa-ncnn`_ since we will quantize our models within
+`sherpa-ncnn`_.
+
+.. code-block:: bash
+
+ # We will download sherpa-ncnn to $HOME/open-source/
+ # You can change it to anywhere you like.
+ cd $HOME
+ mkdir -p open-source
+
+ cd open-source
+ git clone https://github.com/k2-fsa/sherpa-ncnn
+ cd sherpa-ncnn
+ mkdir build
+ cd build
+ cmake ..
+ make -j 4
+
+ ./bin/generate-int8-scale-table
+
+ export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH
+
+The output of the above commands is:
+
+.. code-block:: bash
+
+ (py38) kuangfangjun:build$ generate-int8-scale-table
+ Please provide 10 arg. Currently given: 1
+ Usage:
+ generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt
+
+ Each line in wave_filenames.txt is a path to some 16k Hz mono wave file.
+
+We need to create a file ``wave_filenames.txt``, which lists the calibration
+wave files. For testing purposes, we use the ``test_wavs`` from the pre-trained
+model repository ``_
+
+.. code-block:: bash
+
+ cd egs/librispeech/ASR
+ cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+   cat <<EOF > wave_filenames.txt
+ ../test_wavs/1089-134686-0001.wav
+ ../test_wavs/1221-135766-0001.wav
+ ../test_wavs/1221-135766-0002.wav
+ EOF
+
+Now we can calculate the scales needed for quantization with the calibration data:
+
+.. code-block:: bash
+
+ cd egs/librispeech/ASR
+ cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+ generate-int8-scale-table \
+ ./encoder_jit_trace-pnnx.ncnn.param \
+ ./encoder_jit_trace-pnnx.ncnn.bin \
+ ./decoder_jit_trace-pnnx.ncnn.param \
+ ./decoder_jit_trace-pnnx.ncnn.bin \
+ ./joiner_jit_trace-pnnx.ncnn.param \
+ ./joiner_jit_trace-pnnx.ncnn.bin \
+ ./encoder-scale-table.txt \
+ ./joiner-scale-table.txt \
+ ./wave_filenames.txt
+
+The output logs are given below:
+
+.. literalinclude:: ./code/generate-int-8-scale-table-for-conv-emformer.txt
+
+It generates the following two files:
+
+.. code-block:: bash
+
+ $ ls -lh encoder-scale-table.txt joiner-scale-table.txt
+ -rw-r--r-- 1 kuangfangjun root 955K Jan 11 17:28 encoder-scale-table.txt
+ -rw-r--r-- 1 kuangfangjun root 18K Jan 11 17:28 joiner-scale-table.txt
+
+.. caution::
+
+   In practice, you need more calibration data to compute a reliable scale table.
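+
+For example, you can generate ``wave_filenames.txt`` from a directory of 16 kHz
+mono calibration wave files with a short Python snippet like the one below
+(``/path/to/calibration-wavs`` is a placeholder that you would replace):
+
+.. code-block:: python
+
+   from pathlib import Path
+
+   wav_dir = Path("/path/to/calibration-wavs")  # replace with your own data
+   with open("wave_filenames.txt", "w") as f:
+       for wav in sorted(wav_dir.glob("*.wav")):
+           f.write(f"{wav}\n")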
+
+Finally, let us use the scale table to quantize our models into ``int8``.
+
+.. code-block:: bash
+
+ ncnn2int8
+
+ usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table]
+
+First, we quantize the encoder model:
+
+.. code-block:: bash
+
+ cd egs/librispeech/ASR
+ cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+ ncnn2int8 \
+ ./encoder_jit_trace-pnnx.ncnn.param \
+ ./encoder_jit_trace-pnnx.ncnn.bin \
+ ./encoder_jit_trace-pnnx.ncnn.int8.param \
+ ./encoder_jit_trace-pnnx.ncnn.int8.bin \
+ ./encoder-scale-table.txt
+
+Next, we quantize the joiner model:
+
+.. code-block:: bash
+
+ ncnn2int8 \
+ ./joiner_jit_trace-pnnx.ncnn.param \
+ ./joiner_jit_trace-pnnx.ncnn.bin \
+ ./joiner_jit_trace-pnnx.ncnn.int8.param \
+ ./joiner_jit_trace-pnnx.ncnn.int8.bin \
+ ./joiner-scale-table.txt
+
+The above two commands generate the following 4 files:
+
+.. code-block:: bash
+
+ -rw-r--r-- 1 kuangfangjun root 99M Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.bin
+ -rw-r--r-- 1 kuangfangjun root 78K Jan 11 17:34 encoder_jit_trace-pnnx.ncnn.int8.param
+ -rw-r--r-- 1 kuangfangjun root 774K Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.bin
+ -rw-r--r-- 1 kuangfangjun root 496 Jan 11 17:35 joiner_jit_trace-pnnx.ncnn.int8.param
+
+Congratulations! You have successfully quantized your model from ``float32`` to ``int8``.
+
+.. caution::
+
+ ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs.
+
+ You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param``
+ and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like.
+
+ For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can
+ replace the following invocation:
+
+ .. code-block:: bash
+
+ cd egs/librispeech/ASR
+ cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+ sherpa-ncnn \
+ ../data/lang_bpe_500/tokens.txt \
+ ./encoder_jit_trace-pnnx.ncnn.param \
+ ./encoder_jit_trace-pnnx.ncnn.bin \
+ ./decoder_jit_trace-pnnx.ncnn.param \
+ ./decoder_jit_trace-pnnx.ncnn.bin \
+ ./joiner_jit_trace-pnnx.ncnn.param \
+ ./joiner_jit_trace-pnnx.ncnn.bin \
+ ../test_wavs/1089-134686-0001.wav
+
+ with
+
+ .. code-block::
+
+ cd egs/librispeech/ASR
+ cd icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/
+
+ sherpa-ncnn \
+ ../data/lang_bpe_500/tokens.txt \
+ ./encoder_jit_trace-pnnx.ncnn.int8.param \
+ ./encoder_jit_trace-pnnx.ncnn.int8.bin \
+ ./decoder_jit_trace-pnnx.ncnn.param \
+ ./decoder_jit_trace-pnnx.ncnn.bin \
+ ./joiner_jit_trace-pnnx.ncnn.param \
+ ./joiner_jit_trace-pnnx.ncnn.bin \
+ ../test_wavs/1089-134686-0001.wav
+
+
+The following table compares the file sizes again:
+
+
++----------------------------------------+------------+
+| File name | File size |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.pt | 283 MB |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.pt | 1010 KB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.pt | 3.0 MB |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 142 MB |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 283 MB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.int8.bin | 99 MB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.int8.bin | 774 KB |
++----------------------------------------+------------+
+
+You can see that the file sizes of the models after ``int8`` quantization
+are much smaller.
+
+.. hint::
+
+ Currently, only linear layers and convolutional layers are quantized
+ with ``int8``, so you don't see an exact ``4x`` reduction in file sizes.
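+
+For instance, using the encoder sizes from the table above, the actual
+reduction relative to the ``float32`` model is roughly:
+
+.. code-block:: bash
+
+   # encoder: 283 MB (float32) -> 99 MB (int8)
+   python3 -c 'print(283 / 99)'
+   # ~2.86, i.e., less than a 4x reduction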
+
+.. note::
+
+ You need to test the recognition accuracy after ``int8`` quantization.
+
+You can find the speed comparison at ``_.
+
+
+That's it! Have fun with `sherpa-ncnn`_!
diff --git a/docs/source/model-export/export-ncnn-lstm.rst b/docs/source/model-export/export-ncnn-lstm.rst
new file mode 100644
index 000000000..8e6dc7466
--- /dev/null
+++ b/docs/source/model-export/export-ncnn-lstm.rst
@@ -0,0 +1,644 @@
+.. _export_lstm_transducer_models_to_ncnn:
+
+Export LSTM transducer models to ncnn
+-------------------------------------
+
+We use the pre-trained model from the following repository as an example:
+
+``_
+
+We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_.
+
+.. hint::
+
+ We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing.
+
+.. caution::
+
+ Please use a more recent version of PyTorch. For instance, ``torch 1.8``
+ may ``not`` work.
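+
+You can check which PyTorch version you have installed with:
+
+.. code-block:: bash
+
+   python3 -c 'import torch; print(torch.__version__)'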
+
+1. Download the pre-trained model
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. hint::
+
+ You have to install `git-lfs`_ before you continue.
+
+
+.. code-block:: bash
+
+ cd egs/librispeech/ASR
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+ cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+
+ git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+
+ cd ..
+
+.. note::
+
+ We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``.
+
+In the above code, we downloaded the pre-trained model into the directory
+``egs/librispeech/ASR/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03``.
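+
+As a quick sanity check that ``git lfs pull`` really fetched the files (and did
+not leave small LFS pointer files behind), you can list their sizes:
+
+.. code-block:: bash
+
+   ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/pretrained-iter-468000-avg-16.pt
+   ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/bpe.model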
+
+2. Install ncnn and pnnx
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Please refer to :ref:`export_for_ncnn_install_ncnn_and_pnnx` .
+
+
+3. Export the model via torch.jit.trace()
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+First, let us rename our pre-trained model:
+
+.. code-block::
+
+ cd egs/librispeech/ASR
+
+ cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp
+
+ ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
+
+ cd ../..
+
+Next, we use the following code to export our model:
+
+.. code-block:: bash
+
+ dir=./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+
+ ./lstm_transducer_stateless2/export-for-ncnn.py \
+ --exp-dir $dir/exp \
+ --bpe-model $dir/data/lang_bpe_500/bpe.model \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --num-encoder-layers 12 \
+ --encoder-dim 512 \
+ --rnn-hidden-size 1024
+
+.. hint::
+
+ We have renamed our model to ``epoch-99.pt`` so that we can use ``--epoch 99``.
+ There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``.
+
+ If you have trained a model by yourself and if you have all checkpoints
+ available, please first use ``decode.py`` to tune ``--epoch --avg``
+ and select the best combination with ``--use-averaged-model 1``.
+
+.. note::
+
+ You will see the following log output:
+
+ .. literalinclude:: ./code/export-lstm-transducer-for-ncnn-output.txt
+
+ The log shows the model has ``84176356`` parameters, i.e., ``~84 M``.
+
+ .. code-block::
+
+ ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/pretrained-iter-468000-avg-16.pt
+
+ -rw-r--r-- 1 kuangfangjun root 324M Feb 17 10:34 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/pretrained-iter-468000-avg-16.pt
+
+ You can see that the file size of the pre-trained model is ``324 MB``, which
+ is roughly equal to ``84176356*4/1024/1024 = 321.107 MB``.
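+
+If you want to double-check the arithmetic above, you can run, for example:
+
+.. code-block:: bash
+
+   python3 -c 'print(84176356 * 4 / 1024 / 1024)'
+   # ~321.1 (MB)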
+
+After running ``lstm_transducer_stateless2/export-for-ncnn.py``,
+we will get the following files:
+
+.. code-block:: bash
+
+ ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*pnnx.pt
+
+ -rw-r--r-- 1 kuangfangjun root 1010K Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.pt
+ -rw-r--r-- 1 kuangfangjun root 318M Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.pt
+ -rw-r--r-- 1 kuangfangjun root 3.0M Feb 17 11:22 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.pt
+
+
+.. _lstm-transducer-step-4-export-torchscript-model-via-pnnx:
+
+4. Export torchscript model via pnnx
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. hint::
+
+ Make sure you have set up the ``PATH`` environment variable
+ in :ref:`export_for_ncnn_install_ncnn_and_pnnx`. Otherwise,
+ it will throw an error saying that ``pnnx`` could not be found.
+
+Now, it's time to export our models to `ncnn`_ via ``pnnx``.
+
+.. code-block::
+
+ cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+ pnnx ./encoder_jit_trace-pnnx.pt
+ pnnx ./decoder_jit_trace-pnnx.pt
+ pnnx ./joiner_jit_trace-pnnx.pt
+
+It will generate the following files:
+
+.. code-block:: bash
+
+ ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*ncnn*{bin,param}
+
+ -rw-r--r-- 1 kuangfangjun root 503K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 437 Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param
+ -rw-r--r-- 1 kuangfangjun root 159M Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 21K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param
+ -rw-r--r-- 1 kuangfangjun root 1.5M Feb 17 11:33 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 488 Feb 17 11:33 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param
+
+
+There are two types of files:
+
+- ``param``: It is a text file containing the model architectures. You can
+ use a text editor to view its content.
+- ``bin``: It is a binary file containing the model parameters.
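+
+For example, you can peek at the beginning of the encoder ``param`` file with:
+
+.. code-block:: bash
+
+   head -n 3 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param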
+
+Below, we compare the file sizes of the models before and after conversion via ``pnnx``:
+
+.. see https://tableconvert.com/restructuredtext-generator
+
++----------------------------------+------------+
+| File name | File size |
++==================================+============+
+| encoder_jit_trace-pnnx.pt | 318 MB |
++----------------------------------+------------+
+| decoder_jit_trace-pnnx.pt | 1010 KB |
++----------------------------------+------------+
+| joiner_jit_trace-pnnx.pt | 3.0 MB |
++----------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin | 159 MB |
++----------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin | 503 KB |
++----------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin | 1.5 MB |
++----------------------------------+------------+
+
+You can see that the file sizes of the models after conversion are about half
+of those before conversion:
+
+ - encoder: 318 MB vs 159 MB
+ - decoder: 1010 KB vs 503 KB
+ - joiner: 3.0 MB vs 1.5 MB
+
+The reason is that by default ``pnnx`` converts ``float32`` parameters
+to ``float16``. A ``float32`` parameter occupies 4 bytes, while a ``float16``
+parameter occupies only 2 bytes. Thus, the file size is roughly halved after conversion.
+
+.. hint::
+
+ If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx``
+ won't convert ``float32`` to ``float16``.
+
+5. Test the exported models in icefall
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. note::
+
+ We assume you have set up the environment variable ``PYTHONPATH`` when
+ building `ncnn`_.
+
+Now we have successfully converted our pre-trained model to `ncnn`_ format.
+The generated 6 files are what we need. You can use the following code to
+test the converted models:
+
+.. code-block:: bash
+
+ python3 ./lstm_transducer_stateless2/streaming-ncnn-decode.py \
+ --tokens ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/data/lang_bpe_500/tokens.txt \
+ --encoder-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param \
+ --encoder-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin \
+ --decoder-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param \
+ --decoder-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin \
+ --joiner-param-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param \
+ --joiner-bin-filename ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin \
+ ./icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/test_wavs/1089-134686-0001.wav
+
+.. hint::
+
+ `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts
+ only 1 wave file as input.
+
+The output is given below:
+
+.. literalinclude:: ./code/test-streaming-ncnn-decode-lstm-transducer-libri.txt
+
+Congratulations! You have successfully exported a model from PyTorch to `ncnn`_!
+
+.. _lstm-modify-the-exported-encoder-for-sherpa-ncnn:
+
+6. Modify the exported encoder for sherpa-ncnn
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to use the exported models in `sherpa-ncnn`_, we have to modify
+``encoder_jit_trace-pnnx.ncnn.param``.
+
+Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``:
+
+.. code-block::
+
+ 7767517
+ 267 379
+ Input in0 0 1 in0
+
+**Explanation** of the above three lines:
+
+ 1. ``7767517``, it is a magic number and should not be changed.
+ 2. ``267 379``, the first number ``267`` specifies the number of layers
+ in this file, while ``379`` specifies the number of intermediate outputs
+ of this file.
+ 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0``
+ is the layer name of this layer; ``0`` means this layer has no input;
+ ``1`` means this layer has one output; ``in0`` is the output name of
+ this layer.
+
+We need to add 1 extra line and also increment the number of layers.
+The result looks like the following:
+
+.. code-block:: bash
+
+ 7767517
+ 268 379
+ SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024
+ Input in0 0 1 in0
+
+**Explanation**
+
+ 1. ``7767517``, it is still the same
+ 2. ``268 379``, we have added an extra layer, so we need to update ``267`` to ``268``.
+ We don't need to change ``379`` since the newly added layer has no inputs or outputs.
+ 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024``
+ This line is newly added. Its explanation is given below:
+
+ - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``.
+ - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``.
+ - ``0 0`` means this layer has no inputs or outputs. It must be ``0 0``.
+ - ``0=3``, where 0 is the key and 3 is the value. It MUST be ``0=3``.
+ - ``1=12``, 1 is the key and 12 is the value of the
+ parameter ``--num-encoder-layers`` that you provided when running
+ ``./lstm_transducer_stateless2/export-for-ncnn.py``.
+ - ``2=512``, 2 is the key and 512 is the value of the
+ parameter ``--encoder-dim`` that you provided when running
+ ``./lstm_transducer_stateless2/export-for-ncnn.py``.
+ - ``3=1024``, 3 is the key and 1024 is the value of the
+ parameter ``--rnn-hidden-size`` that you provided when running
+ ``./lstm_transducer_stateless2/export-for-ncnn.py``.
+
+ For ease of reference, we list the key-value pairs that you need to add
+ in the following table. If your model has a different setting, please
+ change the values for ``SherpaMetaData`` accordingly. Otherwise, you
+ will be ``SAD``.
+
+ +------+-----------------------------+
+ | key | value |
+ +======+=============================+
+ | 0 | 3 (fixed) |
+ +------+-----------------------------+
+ | 1 | ``--num-encoder-layers`` |
+ +------+-----------------------------+
+ | 2 | ``--encoder-dim`` |
+ +------+-----------------------------+
+ | 3 | ``--rnn-hidden-size`` |
+ +------+-----------------------------+
+
+ 4. ``Input in0 0 1 in0``. No need to change it.
+
+.. caution::
+
+ When you add a new layer ``SherpaMetaData``, please remember to update the
+ number of layers. In our case, update ``267`` to ``268``. Otherwise,
+ you will be SAD later.
+
+.. hint::
+
+ After adding the new layer ``SherpaMetaData``, you cannot use this model
+ with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is
+ supported only in `sherpa-ncnn`_.
+
+.. hint::
+
+ `ncnn`_ is very flexible. You can add new layers to it just by text-editing
+ the ``param`` file! You don't need to change the ``bin`` file.
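+
+If you prefer to script the edit instead of doing it in a text editor, below is
+a minimal sketch. It assumes GNU ``sed`` and that the file starts with exactly
+the two header lines shown above; back it up first and run it in the directory
+that contains the file.
+
+.. code-block:: bash
+
+   cp encoder_jit_trace-pnnx.ncnn.param encoder_jit_trace-pnnx.ncnn.param.bak
+
+   # 1. bump the layer count from 267 to 268
+   # 2. append the SherpaMetaData line after line 2, i.e., as the new line 3
+   sed -i \
+     -e 's/^267 379$/268 379/' \
+     -e '2a SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024' \
+     encoder_jit_trace-pnnx.ncnn.param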
+
+Now you can use this model in `sherpa-ncnn`_.
+Please refer to the following documentation:
+
+ - Linux/macOS/Windows/arm/aarch64: ``_
+ - ``Android``: ``_
+ - ``iOS``: ``_
+ - Python: ``_
+
+We have a list of pre-trained models that have been exported for `sherpa-ncnn`_:
+
+ - ``_
+
+ You can find more usages there.
+
+7. (Optional) int8 quantization with sherpa-ncnn
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This step is optional.
+
+In this step, we describe how to quantize our model with ``int8``.
+
+Change :ref:`lstm-transducer-step-4-export-torchscript-model-via-pnnx` to
+disable ``fp16`` when using ``pnnx``:
+
+.. code-block::
+
+ cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+ pnnx ./encoder_jit_trace-pnnx.pt fp16=0
+ pnnx ./decoder_jit_trace-pnnx.pt
+ pnnx ./joiner_jit_trace-pnnx.pt fp16=0
+
+.. note::
+
+ We add ``fp16=0`` when exporting the encoder and joiner. `ncnn`_ does not
+ support quantizing the decoder model yet. We will update this documentation
+ once `ncnn`_ supports it (perhaps later in 2023).
+
+.. code-block:: bash
+
+ ls -lh icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/*_jit_trace-pnnx.ncnn.{param,bin}
+
+ -rw-r--r-- 1 kuangfangjun root 503K Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 437 Feb 17 11:32 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/decoder_jit_trace-pnnx.ncnn.param
+ -rw-r--r-- 1 kuangfangjun root 317M Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 21K Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/encoder_jit_trace-pnnx.ncnn.param
+ -rw-r--r-- 1 kuangfangjun root 3.0M Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 488 Feb 17 11:54 icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/joiner_jit_trace-pnnx.ncnn.param
+
+
+Let us compare the file sizes again:
+
++----------------------------------------+------------+
+| File name | File size |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.pt | 318 MB |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.pt | 1010 KB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.pt | 3.0 MB |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 159 MB |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 317 MB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB |
++----------------------------------------+------------+
+
+You can see that the file sizes are doubled when we disable ``fp16``.
+
+.. note::
+
+ You can again use ``streaming-ncnn-decode.py`` to test the exported models.
+
+Next, follow :ref:`lstm-modify-the-exported-encoder-for-sherpa-ncnn`
+to modify ``encoder_jit_trace-pnnx.ncnn.param``.
+
+Change
+
+.. code-block:: bash
+
+ 7767517
+ 267 379
+ Input in0 0 1 in0
+
+to
+
+.. code-block:: bash
+
+ 7767517
+ 268 379
+ SherpaMetaData sherpa_meta_data1 0 0 0=3 1=12 2=512 3=1024
+ Input in0 0 1 in0
+
+.. caution::
+
+ Please follow :ref:`lstm-modify-the-exported-encoder-for-sherpa-ncnn`
+ to change the values for ``SherpaMetaData`` if your model uses a different setting.
+
+Next, let us compile `sherpa-ncnn`_ since we will quantize our models within
+`sherpa-ncnn`_.
+
+.. code-block:: bash
+
+ # We will download sherpa-ncnn to $HOME/open-source/
+ # You can change it to anywhere you like.
+ cd $HOME
+ mkdir -p open-source
+
+ cd open-source
+ git clone https://github.com/k2-fsa/sherpa-ncnn
+ cd sherpa-ncnn
+ mkdir build
+ cd build
+ cmake ..
+ make -j 4
+
+ ./bin/generate-int8-scale-table
+
+ export PATH=$HOME/open-source/sherpa-ncnn/build/bin:$PATH
+
+The output of the above commands is:
+
+.. code-block:: bash
+
+ (py38) kuangfangjun:build$ generate-int8-scale-table
+ Please provide 10 arg. Currently given: 1
+ Usage:
+ generate-int8-scale-table encoder.param encoder.bin decoder.param decoder.bin joiner.param joiner.bin encoder-scale-table.txt joiner-scale-table.txt wave_filenames.txt
+
+ Each line in wave_filenames.txt is a path to some 16k Hz mono wave file.
+
+We need to create a file ``wave_filenames.txt`` containing paths to some
+calibration wave files. For testing purposes, we use the ``test_wavs``
+from the pre-trained model repository
+``_
+
+.. code-block:: bash
+
+ cd egs/librispeech/ASR
+ cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+ cat <<EOF > wave_filenames.txt
+ ../test_wavs/1089-134686-0001.wav
+ ../test_wavs/1221-135766-0001.wav
+ ../test_wavs/1221-135766-0002.wav
+ EOF
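+
+For a real run, you should list more (and more diverse) calibration data here.
+For example, assuming your own 16 kHz mono calibration wave files are stored in
+a single directory, you could build the list with something like:
+
+.. code-block:: bash
+
+   find /path/to/calibration-wavs -name '*.wav' > wave_filenames.txt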
+
+Now we can calculate the scales needed for quantization with the calibration data:
+
+.. code-block:: bash
+
+ cd egs/librispeech/ASR
+ cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+ generate-int8-scale-table \
+ ./encoder_jit_trace-pnnx.ncnn.param \
+ ./encoder_jit_trace-pnnx.ncnn.bin \
+ ./decoder_jit_trace-pnnx.ncnn.param \
+ ./decoder_jit_trace-pnnx.ncnn.bin \
+ ./joiner_jit_trace-pnnx.ncnn.param \
+ ./joiner_jit_trace-pnnx.ncnn.bin \
+ ./encoder-scale-table.txt \
+ ./joiner-scale-table.txt \
+ ./wave_filenames.txt
+
+The output logs are given below:
+
+.. literalinclude:: ./code/generate-int-8-scale-table-for-lstm.txt
+
+It generates the following two files:
+
+.. code-block:: bash
+
+ ls -lh encoder-scale-table.txt joiner-scale-table.txt
+
+ -rw-r--r-- 1 kuangfangjun root 345K Feb 17 12:13 encoder-scale-table.txt
+ -rw-r--r-- 1 kuangfangjun root 17K Feb 17 12:13 joiner-scale-table.txt
+
+.. caution::
+
+ In practice, you need more calibration data to compute an accurate scale table.
+
+Finally, let us use the scale table to quantize our models into ``int8``.
+
+.. code-block:: bash
+
+ ncnn2int8
+
+ usage: ncnn2int8 [inparam] [inbin] [outparam] [outbin] [calibration table]
+
+First, we quantize the encoder model:
+
+.. code-block:: bash
+
+ cd egs/librispeech/ASR
+ cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+ ncnn2int8 \
+ ./encoder_jit_trace-pnnx.ncnn.param \
+ ./encoder_jit_trace-pnnx.ncnn.bin \
+ ./encoder_jit_trace-pnnx.ncnn.int8.param \
+ ./encoder_jit_trace-pnnx.ncnn.int8.bin \
+ ./encoder-scale-table.txt
+
+Next, we quantize the joiner model:
+
+.. code-block:: bash
+
+ ncnn2int8 \
+ ./joiner_jit_trace-pnnx.ncnn.param \
+ ./joiner_jit_trace-pnnx.ncnn.bin \
+ ./joiner_jit_trace-pnnx.ncnn.int8.param \
+ ./joiner_jit_trace-pnnx.ncnn.int8.bin \
+ ./joiner-scale-table.txt
+
+The above two commands generate the following 4 files:
+
+.. code-block::
+
+ -rw-r--r-- 1 kuangfangjun root 218M Feb 17 12:19 encoder_jit_trace-pnnx.ncnn.int8.bin
+ -rw-r--r-- 1 kuangfangjun root 21K Feb 17 12:19 encoder_jit_trace-pnnx.ncnn.int8.param
+ -rw-r--r-- 1 kuangfangjun root 774K Feb 17 12:19 joiner_jit_trace-pnnx.ncnn.int8.bin
+ -rw-r--r-- 1 kuangfangjun root 496 Feb 17 12:19 joiner_jit_trace-pnnx.ncnn.int8.param
+
+Congratulations! You have successfully quantized your model from ``float32`` to ``int8``.
+
+.. caution::
+
+ ``ncnn.int8.param`` and ``ncnn.int8.bin`` must be used in pairs.
+
+ You can replace ``ncnn.param`` and ``ncnn.bin`` with ``ncnn.int8.param``
+ and ``ncnn.int8.bin`` in `sherpa-ncnn`_ if you like.
+
+ For instance, to use only the ``int8`` encoder in ``sherpa-ncnn``, you can
+ replace the following invocation:
+
+ .. code-block:: bash
+
+ cd egs/librispeech/ASR
+ cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+ sherpa-ncnn \
+ ../data/lang_bpe_500/tokens.txt \
+ ./encoder_jit_trace-pnnx.ncnn.param \
+ ./encoder_jit_trace-pnnx.ncnn.bin \
+ ./decoder_jit_trace-pnnx.ncnn.param \
+ ./decoder_jit_trace-pnnx.ncnn.bin \
+ ./joiner_jit_trace-pnnx.ncnn.param \
+ ./joiner_jit_trace-pnnx.ncnn.bin \
+ ../test_wavs/1089-134686-0001.wav
+
+ with
+
+ .. code-block:: bash
+
+ cd egs/librispeech/ASR
+ cd icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp/
+
+ sherpa-ncnn \
+ ../data/lang_bpe_500/tokens.txt \
+ ./encoder_jit_trace-pnnx.ncnn.int8.param \
+ ./encoder_jit_trace-pnnx.ncnn.int8.bin \
+ ./decoder_jit_trace-pnnx.ncnn.param \
+ ./decoder_jit_trace-pnnx.ncnn.bin \
+ ./joiner_jit_trace-pnnx.ncnn.param \
+ ./joiner_jit_trace-pnnx.ncnn.bin \
+ ../test_wavs/1089-134686-0001.wav
+
+The following table compares the file sizes again:
+
++----------------------------------------+------------+
+| File name | File size |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.pt | 318 MB |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.pt | 1010 KB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.pt | 3.0 MB |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp16) | 159 MB |
++----------------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin (fp16) | 503 KB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp16) | 1.5 MB |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin (fp32) | 317 MB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin (fp32) | 3.0 MB |
++----------------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.int8.bin | 218 MB |
++----------------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.int8.bin | 774 KB |
++----------------------------------------+------------+
+
+You can see that the file size of the joiner model after ``int8`` quantization
+is much smaller. However, the encoder model is even larger than its ``fp16``
+counterpart. The reason is that `ncnn`_ currently does not support quantizing
+``LSTM`` layers into ``8-bit``. Please see
+``_
+
+.. hint::
+
+ Currently, only linear layers and convolutional layers are quantized
+ with ``int8``, so you don't see an exact ``4x`` reduction in file sizes.
+
+.. note::
+
+ You need to test the recognition accuracy after ``int8`` quantization.
+
+
+That's it! Have fun with `sherpa-ncnn`_!
diff --git a/docs/source/model-export/export-ncnn-zipformer.rst b/docs/source/model-export/export-ncnn-zipformer.rst
new file mode 100644
index 000000000..5c81d25ca
--- /dev/null
+++ b/docs/source/model-export/export-ncnn-zipformer.rst
@@ -0,0 +1,383 @@
+.. _export_streaming_zipformer_transducer_models_to_ncnn:
+
+Export streaming Zipformer transducer models to ncnn
+----------------------------------------------------
+
+We use the pre-trained model from the following repository as an example:
+
+``_
+
+We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_.
+
+.. hint::
+
+ We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing.
+
+.. caution::
+
+ Please use a more recent version of PyTorch. For instance, ``torch 1.8``
+ may ``not`` work.
+
+1. Download the pre-trained model
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. hint::
+
+ You have to install `git-lfs`_ before you continue.
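+
+A quick way to check whether `git-lfs`_ is installed on your system:
+
+.. code-block:: bash
+
+   git lfs version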
+
+
+.. code-block:: bash
+
+ cd egs/librispeech/ASR
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+
+ git lfs pull --include "exp/pretrained.pt"
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+
+ cd ..
+
+.. note::
+
+ We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``.
+
+In the above code, we downloaded the pre-trained model into the directory
+``egs/librispeech/ASR/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29``.
+
+2. Install ncnn and pnnx
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Please refer to :ref:`export_for_ncnn_install_ncnn_and_pnnx` .
+
+
+3. Export the model via torch.jit.trace()
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+First, let us rename our pre-trained model:
+
+.. code-block::
+
+ cd egs/librispeech/ASR
+
+ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
+
+ ln -s pretrained.pt epoch-99.pt
+
+ cd ../..
+
+Next, we use the following code to export our model:
+
+.. code-block:: bash
+
+ dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+
+ ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
+ --bpe-model $dir/data/lang_bpe_500/bpe.model \
+ --exp-dir $dir/exp \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ \
+ --decode-chunk-len 32 \
+ --num-left-chunks 4 \
+ --num-encoder-layers "2,4,3,2,4" \
+ --feedforward-dims "1024,1024,2048,2048,1024" \
+ --nhead "8,8,8,8,8" \
+ --encoder-dims "384,384,384,384,384" \
+ --attention-dims "192,192,192,192,192" \
+ --encoder-unmasked-dims "256,256,256,256,256" \
+ --zipformer-downsampling-factors "1,2,4,8,2" \
+ --cnn-module-kernels "31,31,31,31,31" \
+ --decoder-dim 512 \
+ --joiner-dim 512
+
+.. caution::
+
+ If your model has different configuration parameters, please change them accordingly.
+
+.. hint::
+
+ We have renamed our model to ``epoch-99.pt`` so that we can use ``--epoch 99``.
+ There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``.
+
+ If you have trained a model by yourself and if you have all checkpoints
+ available, please first use ``decode.py`` to tune ``--epoch --avg``
+ and select the best combination with ``--use-averaged-model 1``.
+
+.. note::
+
+ You will see the following log output:
+
+ .. literalinclude:: ./code/export-zipformer-transducer-for-ncnn-output.txt
+
+ The log shows the model has ``69920376`` parameters, i.e., ``~69.9 M``.
+
+ .. code-block:: bash
+
+ ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/pretrained.pt
+ -rw-r--r-- 1 kuangfangjun root 269M Jan 12 12:53 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/pretrained.pt
+
+ You can see that the file size of the pre-trained model is ``269 MB``, which
+ is roughly equal to ``69920376*4/1024/1024 = 266.725 MB``.
+
+After running ``pruned_transducer_stateless7_streaming/export-for-ncnn.py``,
+we will get the following files:
+
+.. code-block:: bash
+
+ ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/*pnnx.pt
+
+ -rw-r--r-- 1 kuangfangjun root 1022K Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.pt
+ -rw-r--r-- 1 kuangfangjun root 266M Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.pt
+ -rw-r--r-- 1 kuangfangjun root 2.8M Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.pt
+
+.. _zipformer-transducer-step-4-export-torchscript-model-via-pnnx:
+
+4. Export torchscript model via pnnx
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. hint::
+
+ Make sure you have set up the ``PATH`` environment variable
+ in :ref:`export_for_ncnn_install_ncnn_and_pnnx`. Otherwise,
+ it will throw an error saying that ``pnnx`` could not be found.
+
+Now, it's time to export our models to `ncnn`_ via ``pnnx``.
+
+.. code-block::
+
+ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
+
+ pnnx ./encoder_jit_trace-pnnx.pt
+ pnnx ./decoder_jit_trace-pnnx.pt
+ pnnx ./joiner_jit_trace-pnnx.pt
+
+It will generate the following files:
+
+.. code-block:: bash
+
+ ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/*ncnn*{bin,param}
+
+ -rw-r--r-- 1 kuangfangjun root 509K Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 437 Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param
+ -rw-r--r-- 1 kuangfangjun root 133M Feb 27 20:30 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 152K Feb 27 20:30 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param
+ -rw-r--r-- 1 kuangfangjun root 1.4M Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin
+ -rw-r--r-- 1 kuangfangjun root 488 Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param
+
+There are two types of files:
+
+- ``param``: It is a text file containing the model architectures. You can
+ use a text editor to view its content.
+- ``bin``: It is a binary file containing the model parameters.
+
+Below, we compare the file sizes of the models before and after conversion via ``pnnx``:
+
+.. see https://tableconvert.com/restructuredtext-generator
+
++----------------------------------+------------+
+| File name | File size |
++==================================+============+
+| encoder_jit_trace-pnnx.pt | 266 MB |
++----------------------------------+------------+
+| decoder_jit_trace-pnnx.pt | 1022 KB |
++----------------------------------+------------+
+| joiner_jit_trace-pnnx.pt | 2.8 MB |
++----------------------------------+------------+
+| encoder_jit_trace-pnnx.ncnn.bin | 133 MB |
++----------------------------------+------------+
+| decoder_jit_trace-pnnx.ncnn.bin | 509 KB |
++----------------------------------+------------+
+| joiner_jit_trace-pnnx.ncnn.bin | 1.4 MB |
++----------------------------------+------------+
+
+You can see that the file sizes of the models after conversion are about half
+of those before conversion:
+
+ - encoder: 266 MB vs 133 MB
+ - decoder: 1022 KB vs 509 KB
+ - joiner: 2.8 MB vs 1.4 MB
+
+The reason is that by default ``pnnx`` converts ``float32`` parameters
+to ``float16``. A ``float32`` parameter occupies 4 bytes, while a ``float16``
+parameter occupies only 2 bytes. Thus, the file size is roughly halved after conversion.
+
+.. hint::
+
+ If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx``
+ won't convert ``float32`` to ``float16``.
+
+5. Test the exported models in icefall
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. note::
+
+ We assume you have set up the environment variable ``PYTHONPATH`` when
+ building `ncnn`_.
+
+Now we have successfully converted our pre-trained model to `ncnn`_ format.
+The generated 6 files are what we need. You can use the following code to
+test the converted models:
+
+.. code-block:: bash
+
+ python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
+ --tokens ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/tokens.txt \
+ --encoder-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param \
+ --encoder-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin \
+ --decoder-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param \
+ --decoder-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin \
+ --joiner-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param \
+ --joiner-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin \
+ ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav
+
+.. hint::
+
+ `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts
+ only 1 wave file as input.
+
+The output is given below:
+
+.. literalinclude:: ./code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt
+
+Congratulations! You have successfully exported a model from PyTorch to `ncnn`_!
+
+.. _zipformer-modify-the-exported-encoder-for-sherpa-ncnn:
+
+6. Modify the exported encoder for sherpa-ncnn
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to use the exported models in `sherpa-ncnn`_, we have to modify
+``encoder_jit_trace-pnnx.ncnn.param``.
+
+Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``:
+
+.. code-block::
+
+ 7767517
+ 2028 2547
+ Input in0 0 1 in0
+
+**Explanation** of the above three lines:
+
+ 1. ``7767517``, it is a magic number and should not be changed.
+ 2. ``2028 2547``, the first number ``2028`` specifies the number of layers
+ in this file, while ``2547`` specifies the number of intermediate outputs
+ of this file.
+ 3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0``
+ is the layer name of this layer; ``0`` means this layer has no input;
+ ``1`` means this layer has one output; ``in0`` is the output name of
+ this layer.
+
+We need to add 1 extra line and also increment the number of layers.
+The result looks like the following:
+
+.. code-block:: bash
+
+ 7767517
+ 2029 2547
+ SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31
+ Input in0 0 1 in0
+
+**Explanation**
+
+ 1. ``7767517``, it is still the same
+ 2. ``2029 2547``, we have added an extra layer, so we need to update ``2028`` to ``2029``.
+ We don't need to change ``2547`` since the newly added layer has no inputs or outputs.
+ 3. ``SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31``
+ This line is newly added. Its explanation is given below:
+
+ - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``.
+ - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``.
+ - ``0 0`` means this layer has no inputs or outputs. It must be ``0 0``.
+ - ``0=2``, where 0 is the key and 2 is the value. It MUST be ``0=2``.
+ - ``1=32``, 1 is the key and 32 is the value of the
+ parameter ``--decode-chunk-len`` that you provided when running
+ ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+ - ``2=4``, 2 is the key and 4 is the value of the
+ parameter ``--num-left-chunks`` that you provided when running
+ ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+ - ``3=7``, 3 is the key and 7 is the value for the amount of padding
+ used in the Conv2DSubsampling layer. It should be 7 for zipformer
+ if you don't change zipformer.py.
+ - ``-23316=5,2,4,3,2,4``, attribute 16, this is an array attribute.
+ It is attribute 16 since -23300 - (-23316) = 16.
+ The first element of the array is the length of the array, which is 5 in our case.
+ ``2,4,3,2,4`` is the value of ``--num-encoder-layers`` that you provided
+ when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+ - ``-23317=5,384,384,384,384,384``, attribute 17.
+ The first element of the array is the length of the array, which is 5 in our case.
+ ``384,384,384,384,384`` is the value of ``--encoder-dims`` that you provided
+ when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+ - ``-23318=5,192,192,192,192,192``, attribute 18.
+ The first element of the array is the length of the array, which is 5 in our case.
+ ``192,192,192,192,192`` is the value of ``--attention-dims`` that you provided
+ when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+ - ``-23319=5,1,2,4,8,2``, attribute 19.
+ The first element of the array is the length of the array, which is 5 in our case.
+ ``1,2,4,8,2`` is the value of ``--zipformer-downsampling-factors`` that you provided
+ when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+ - ``-23320=5,31,31,31,31,31``, attribute 20.
+ The first element of the array is the length of the array, which is 5 in our case.
+ ``31,31,31,31,31`` is the value of ``--cnn-module-kernels`` that you provided
+ when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
+
+ For ease of reference, we list the key-value pairs that you need to add
+ in the following table. If your model has a different setting, please
+ change the values for ``SherpaMetaData`` accordingly. Otherwise, you
+ will be ``SAD``.
+
+ +--------+----------------------------------------+
+ | key    | value                                  |
+ +========+========================================+
+ | 0      | 2 (fixed)                              |
+ +--------+----------------------------------------+
+ | 1      | ``--decode-chunk-len``                 |
+ +--------+----------------------------------------+
+ | 2      | ``--num-left-chunks``                  |
+ +--------+----------------------------------------+
+ | 3      | 7 (if you don't change code)           |
+ +--------+----------------------------------------+
+ | -23316 | ``--num-encoder-layers``               |
+ +--------+----------------------------------------+
+ | -23317 | ``--encoder-dims``                     |
+ +--------+----------------------------------------+
+ | -23318 | ``--attention-dims``                   |
+ +--------+----------------------------------------+
+ | -23319 | ``--zipformer-downsampling-factors``   |
+ +--------+----------------------------------------+
+ | -23320 | ``--cnn-module-kernels``               |
+ +--------+----------------------------------------+
+
+ 4. ``Input in0 0 1 in0``. No need to change it.
+
+.. caution::
+
+ When you add a new layer ``SherpaMetaData``, please remember to update the
+ number of layers. In our case, update ``2028`` to ``2029``. Otherwise,
+ you will be SAD later.
+
+.. hint::
+
+ After adding the new layer ``SherpaMetaData``, you cannot use this model
+ with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is
+ supported only in `sherpa-ncnn`_.
+
+.. hint::
+
+ `ncnn`_ is very flexible. You can add new layers to it just by text-editing
+ the ``param`` file! You don't need to change the ``bin`` file.
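+
+After editing the file, a quick sanity check is to look at its first few lines
+again and confirm that the layer count was updated and the ``SherpaMetaData``
+line is in place:
+
+.. code-block:: bash
+
+   head -n 4 ./encoder_jit_trace-pnnx.ncnn.param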
+
+Now you can use this model in `sherpa-ncnn`_.
+Please refer to the following documentation:
+
+ - Linux/macOS/Windows/arm/aarch64: ``_
+ - ``Android``: ``_
+ - ``iOS``: ``_
+ - Python: ``_
+
+We have a list of pre-trained models that have been exported for `sherpa-ncnn`_:
+
+ - ``_
+
+ You can find more usages there.
diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst
index 3dbb8b514..9eb5f85d2 100644
--- a/docs/source/model-export/export-ncnn.rst
+++ b/docs/source/model-export/export-ncnn.rst
@@ -1,12 +1,37 @@
Export to ncnn
==============
-We support exporting LSTM transducer models to `ncnn `_.
+We support exporting the following models
+to `ncnn `_:
-Please refer to :ref:`export-model-for-ncnn` for details.
+ - `Zipformer transducer models `_
-We also provide ``_
-performing speech recognition using ``ncnn`` with exported models.
-It has been tested on Linux, macOS, Windows, and Raspberry Pi. The project is
-self-contained and can be statically linked to produce a binary containing
-everything needed.
+ - `LSTM transducer models `_
+
+ - `ConvEmformer transducer models `_
+
+We also provide `sherpa-ncnn`_
+for performing speech recognition using `ncnn`_ with exported models.
+It has been tested on the following platforms:
+
+ - Linux
+ - macOS
+ - Windows
+ - ``Android``
+ - ``iOS``
+ - ``Raspberry Pi``
+ - `爱芯派 `_ (`MAIX-III AXera-Pi `_).
+ - `RV1126 `_
+
+`sherpa-ncnn`_ is self-contained and can be statically linked to produce
+a binary containing everything needed. Please refer
+to its documentation for details:
+
+ - ``_
+
+
+.. toctree::
+
+ export-ncnn-zipformer
+ export-ncnn-conv-emformer
+ export-ncnn-lstm
diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst
index dd4b3437a..aa77204cb 100644
--- a/docs/source/model-export/export-onnx.rst
+++ b/docs/source/model-export/export-onnx.rst
@@ -1,69 +1,95 @@
Export to ONNX
==============
-In this section, we describe how to export models to ONNX.
+In this section, we describe how to export models to `ONNX`_.
+
+In each recipe, there is a file called ``export-onnx.py``, which is used
+to export trained models to `ONNX`_.
+
+There is also a file named ``onnx_pretrained.py``, which you can use to run
+the exported `ONNX`_ model in Python with `onnxruntime`_ and decode sound files.
+
+sherpa-onnx
+-----------
+
+We have a separate repository `sherpa-onnx`_ for deploying your exported models
+on various platforms such as:
+
+ - iOS
+ - Android
+ - Raspberry Pi
+ - Linux/macOS/Windows
+
+
+Please see the documentation of `sherpa-onnx`_ for details:
+
+ ``_
+
+Example
+-------
+
+In the following, we demonstrate how to export a streaming Zipformer pre-trained
+model from
+``_
+to `ONNX`_.
+
+Download the pre-trained model
+------------------------------
.. hint::
- Only non-streaming conformer transducer models are tested.
-
-
-When to use it
---------------
-
-It you want to use an inference framework that supports ONNX
-to run the pretrained model.
-
-
-How to export
--------------
-
-We use
-``_
-as an example in the following.
+ We assume you have installed `git-lfs`_.
.. code-block:: bash
- cd egs/librispeech/ASR
- epoch=14
- avg=2
- ./pruned_transducer_stateless3/export.py \
- --exp-dir ./pruned_transducer_stateless3/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
- --epoch $epoch \
- --avg $avg \
- --onnx 1
+ cd egs/librispeech/ASR
-It will generate the following files inside ``pruned_transducer_stateless3/exp``:
+ repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
- - ``encoder.onnx``
- - ``decoder.onnx``
- - ``joiner.onnx``
- - ``joiner_encoder_proj.onnx``
- - ``joiner_decoder_proj.onnx``
+ pushd $repo
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "exp/pretrained.pt"
+ cd exp
+ ln -s pretrained.pt epoch-99.pt
+ popd
-You can use ``./pruned_transducer_stateless3/exp/onnx_pretrained.py`` to decode
-waves with the generated files:
+Export the model to ONNX
+------------------------
.. code-block:: bash
- ./pruned_transducer_stateless3/onnx_pretrained.py \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
- --decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
- --joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
- --joiner-encoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_encoder_proj.onnx \
- --joiner-decoder-proj-model-filename ./pruned_transducer_stateless3/exp/joiner_decoder_proj.onnx \
- /path/to/foo.wav \
- /path/to/bar.wav \
- /path/to/baz.wav
+ ./pruned_transducer_stateless7_streaming/export-onnx.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --decode-chunk-len 32 \
+ --exp-dir $repo/exp/
+.. warning::
-How to use the exported model
------------------------------
+ ``export-onnx.py`` from different recipes has different options.
-We also provide ``_
-performing speech recognition using `onnxruntime `_
-with exported models.
-It has been tested on Linux, macOS, and Windows.
+ In the above example, ``--decode-chunk-len`` is specific to the
+ streaming Zipformer. Other models won't have such an option.
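+
+To see the full list of options supported by the export script in your recipe,
+you can always run it with ``--help``, e.g.:
+
+.. code-block:: bash
+
+   ./pruned_transducer_stateless7_streaming/export-onnx.py --help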
+
+It will generate the following 3 files in ``$repo/exp``:
+
+ - ``encoder-epoch-99-avg-1.onnx``
+ - ``decoder-epoch-99-avg-1.onnx``
+ - ``joiner-epoch-99-avg-1.onnx``
+
+Decode sound files with exported ONNX models
+--------------------------------------------
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav
diff --git a/docs/source/model-export/export-with-torch-jit-script.rst b/docs/source/model-export/export-with-torch-jit-script.rst
index a041dc1d5..efd7dc2e1 100644
--- a/docs/source/model-export/export-with-torch-jit-script.rst
+++ b/docs/source/model-export/export-with-torch-jit-script.rst
@@ -1,7 +1,7 @@
.. _export-model-with-torch-jit-script:
Export model with torch.jit.script()
-===================================
+====================================
In this section, we describe how to export a model via
``torch.jit.script()``.
diff --git a/docs/source/recipes/aishell/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst
similarity index 99%
rename from docs/source/recipes/aishell/conformer_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst
index 72690e102..6e30ce397 100644
--- a/docs/source/recipes/aishell/conformer_ctc.rst
+++ b/docs/source/recipes/Non-streaming-ASR/aishell/conformer_ctc.rst
@@ -703,7 +703,7 @@ It will show you the following message:
HLG decoding
-^^^^^^^^^^^^
+~~~~~~~~~~~~
.. code-block:: bash
diff --git a/docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
similarity index 100%
rename from docs/source/recipes/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-conformer-ctc-tensorboard-log.jpg
diff --git a/docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
similarity index 100%
rename from docs/source/recipes/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-tdnn-lstm-ctc-tensorboard-log.jpg
diff --git a/docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png
rename to docs/source/recipes/Non-streaming-ASR/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png
diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/Non-streaming-ASR/aishell/index.rst
similarity index 99%
rename from docs/source/recipes/aishell/index.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/index.rst
index d072d6e9c..b77d59bca 100644
--- a/docs/source/recipes/aishell/index.rst
+++ b/docs/source/recipes/Non-streaming-ASR/aishell/index.rst
@@ -19,4 +19,3 @@ It can be downloaded from ``_
tdnn_lstm_ctc
conformer_ctc
stateless_transducer
-
diff --git a/docs/source/recipes/aishell/stateless_transducer.rst b/docs/source/recipes/Non-streaming-ASR/aishell/stateless_transducer.rst
similarity index 100%
rename from docs/source/recipes/aishell/stateless_transducer.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/stateless_transducer.rst
diff --git a/docs/source/recipes/aishell/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst
similarity index 100%
rename from docs/source/recipes/aishell/tdnn_lstm_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/aishell/tdnn_lstm_ctc.rst
diff --git a/docs/source/recipes/Non-streaming-ASR/index.rst b/docs/source/recipes/Non-streaming-ASR/index.rst
new file mode 100644
index 000000000..67123a648
--- /dev/null
+++ b/docs/source/recipes/Non-streaming-ASR/index.rst
@@ -0,0 +1,10 @@
+Non Streaming ASR
+=================
+
+.. toctree::
+ :maxdepth: 2
+
+ aishell/index
+ librispeech/index
+ timit/index
+ yesno/index
diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst
similarity index 99%
rename from docs/source/recipes/librispeech/conformer_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst
index 4656acfd6..b7f89c89f 100644
--- a/docs/source/recipes/librispeech/conformer_ctc.rst
+++ b/docs/source/recipes/Non-streaming-ASR/librispeech/conformer_ctc.rst
@@ -888,7 +888,7 @@ It will show you the following message:
CTC decoding
-^^^^^^^^^^^^
+~~~~~~~~~~~~
.. code-block:: bash
@@ -926,7 +926,7 @@ Its output is:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
HLG decoding
-^^^^^^^^^^^^
+~~~~~~~~~~~~
.. code-block:: bash
@@ -966,7 +966,7 @@ The output is:
HLG decoding + n-gram LM rescoring
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
@@ -1012,7 +1012,7 @@ The output is:
HLG decoding + n-gram LM rescoring + attention decoder rescoring
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: bash
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst
new file mode 100644
index 000000000..ea9f350cd
--- /dev/null
+++ b/docs/source/recipes/Non-streaming-ASR/librispeech/distillation.rst
@@ -0,0 +1,223 @@
+Distillation with HuBERT
+========================
+
+This tutorial shows you how to perform knowledge distillation in `icefall`_
+with the `LibriSpeech`_ dataset. The distillation method
+used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD).
+Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation `_
+for more details about MVQ-KD.
+
+.. note::
+
+ This tutorial is based on recipe
+ `pruned_transducer_stateless4 `_.
+ Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes
+ with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you
+ encounter any problems, please open an issue in `icefall `_.
+
+.. note::
+
+ We assume you have read the page :ref:`install icefall` and have set up
+ the environment for `icefall`_.
+
+.. HINT::
+
+ We recommend using one or more GPUs to run this recipe.
+
+Data preparation
+----------------
+
+We first prepare the necessary training data for `LibriSpeech`_.
+This is the same as in :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
+
+.. hint::
+
+ The data preparation is the same as for other LibriSpeech recipes. If you
+ have already finished this step, you can skip directly to :ref:`codebook_index_preparation`.
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+The data preparation contains several stages. You can use the following two
+options:
+
+ - ``--stage``
+ - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+For example,
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh --stage 0 --stop-stage 0 # run only stage 0
+ $ ./prepare.sh --stage 2 --stop-stage 5 # run from stage 2 to stage 5
+
+.. HINT::
+
+ If you have pre-downloaded the `LibriSpeech`_
+ dataset and the `musan`_ dataset, say,
+ they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
+ the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+ ``./prepare.sh`` won't re-download them.
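+
+For example, assuming the datasets are stored in ``/tmp``, you could point
+``dl_dir`` there either by editing ``./prepare.sh`` in an editor or with a
+one-liner like the following (a sketch; it assumes ``prepare.sh`` defines
+``dl_dir=...`` on a line of its own, as the hint above describes):
+
+.. code-block:: bash
+
+   sed -i 's|^dl_dir=.*|dl_dir=/tmp|' ./prepare.sh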
+
+.. NOTE::
+
+ All files generated by ``./prepare.sh``, e.g., features, lexicon, etc.,
+ are saved in the ``./data`` directory.
+
+We provide the following YouTube video showing how to run ``./prepare.sh``.
+
+.. note::
+
+ To get the latest news of `next-gen Kaldi `_, please subscribe to
+ the following YouTube channel by `Nadira Povey `_:
+
+ ``_
+
+.. youtube:: ofEIoJL-mGM
+
+
+.. _codebook_index_preparation:
+
+Codebook index preparation
+--------------------------
+
+Here, we prepare the necessary data for MVQ-KD. This requires the generation
+of codebook indexes (please read our `paper `_
+if you are interested in the details). In this tutorial, we use the pre-computed
+codebook indexes for convenience. The only thing you need to do is to
+run `./distillation_with_hubert.sh `_.
+
+.. note::
+
+ There are 5 stages in total. The first and second stages will be skipped automatically
+ if you choose to download the codebook indexes prepared by `icefall`_.
+ Of course, you can also extract and compute the codebook indexes yourself. This
+ requires downloading a HuBERT-XL model, and extracting the codebook indexes
+ can take a while.
+
+
+As usual, you can control the stages you want to run by specifying the following
+two options:
+
+ - ``--stage``
+ - ``--stop-stage``
+
+For example,
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./distillation_with_hubert.sh --stage 0 --stop-stage 0 # run only stage 0
+ $ ./distillation_with_hubert.sh --stage 2 --stop-stage 4 # run from stage 2 to stage 4
+
+Here are a few options in `./distillation_with_hubert.sh `_
+you need to know before you proceed.
+
+- ``--full_libri`` If True, use the full 960h of data. Otherwise, only ``train-clean-100`` will be used.
+- ``--use_extracted_codebook`` If True, the first two stages will be skipped and the codebook
+ indexes uploaded by us will be downloaded.
+
+Since we are using the pre-computed codebook indexes, we set
+``use_extracted_codebook=True``. If you want to do full `LibriSpeech`_
+experiments, please set ``full_libri=True``.
+
+The following command downloads the pre-computed codebook indexes
+and prepares MVQ-augmented training manifests.
+
+.. code-block:: bash
+
+ $ ./distillation_with_hubert.sh --stage 2 --stop-stage 2 # run only stage 2
+
+Please see the
+following screenshot for the output of an example execution.
+
+.. figure:: ./images/distillation_codebook.png
+ :width: 800
+ :alt: Downloading codebook indexes and preparing training manifest.
+ :align: center
+
+ Downloading codebook indexes and preparing training manifest.
+
+.. hint::
+
+ The codebook indexes we prepared for you in this tutorial
+ are extracted from the 36-th layer of a fine-tuned HuBERT-XL model
+ with 8 codebooks. If you want to try other configurations, please
+ set ``use_extracted_codebook=False`` and set ``embedding_layer`` and
+ ``num_codebooks`` by yourself.
+
+Now, you should see the following files under the directory ``./data/vq_fbank_layer36_cb8``.
+
+.. figure:: ./images/distillation_directory.png
+ :width: 800
+ :alt: MVQ-augmented training manifests
+ :align: center
+
+ MVQ-augmented training manifests.
+
+Voilà! You are now ready to perform knowledge distillation training!
+
+Training
+--------
+
+To perform training, please run stage 3 by executing the following command.
+
+.. code-block:: bash
+
+   $ ./distillation_with_hubert.sh --stage 3 --stop-stage 3 # run MVQ training
+
+Here is the code snippet for training:
+
+.. code-block:: bash
+
+ WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}')
+
+ ./pruned_transducer_stateless6/train.py \
+ --manifest-dir ./data/vq_fbank_layer36_cb8 \
+ --master-port 12359 \
+ --full-libri $full_libri \
+ --spec-aug-time-warp-factor -1 \
+ --max-duration 300 \
+ --world-size ${WORLD_SIZE} \
+ --num-epochs 30 \
+ --exp-dir $exp_dir \
+ --enable-distillation True \
+ --codebook-loss-scale 0.01
+
+There are a few arguments in the above training command that deserve
+special attention (a sketch of how they combine follows the list):
+
+ - ``--enable-distillation`` If True, knowledge distillation training is enabled.
+ - ``--codebook-loss-scale`` The scale of the knowledge distillation loss.
+ - ``--manifest-dir`` The path to the MVQ-augmented manifest.
+
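+For reference, the snippet below is a hedged sketch of how these options are
+typically combined into the final loss; the variable names and values are
+illustrative and do not necessarily match those in the training script.
+
+.. code-block:: python
+
+   import torch
+
+   # Stand-ins for the transducer loss and the MVQ distillation loss of a batch.
+   rnnt_loss = torch.tensor(120.0)
+   codebook_loss = torch.tensor(950.0)
+
+   enable_distillation = True
+   codebook_loss_scale = 0.01
+
+   loss = rnnt_loss
+   if enable_distillation:
+       loss = loss + codebook_loss_scale * codebook_loss
+   print(loss)  # tensor(129.5000)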
+
+Decoding
+--------
+
+After training finishes, you can test the performance using
+the following command.
+
+.. code-block:: bash
+
+ export CUDA_VISIBLE_DEVICES=0
+   ./pruned_transducer_stateless6/decode.py \
+ --decoding-method "modified_beam_search" \
+ --epoch 30 \
+ --avg 10 \
+ --max-duration 200 \
+ --exp-dir $exp_dir \
+ --enable-distillation True
+
+You should get results similar to those listed `here `_.
+
+That's all! Feel free to experiment with your own setups and report your results.
+If you encounter any problems during training, please open up an issue `here `_.
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png
new file mode 100644
index 000000000..1a40d6c6e
Binary files /dev/null and b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_codebook.png differ
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_directory.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_directory.png
new file mode 100644
index 000000000..30763046f
Binary files /dev/null and b/docs/source/recipes/Non-streaming-ASR/librispeech/images/distillation_directory.png differ
diff --git a/docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
rename to docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-conformer-ctc-tensorboard-log.png
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg
new file mode 100644
index 000000000..800835749
Binary files /dev/null and b/docs/source/recipes/Non-streaming-ASR/librispeech/images/librispeech-pruned-transducer-tensorboard-log.jpg differ
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst
new file mode 100644
index 000000000..bf439861a
--- /dev/null
+++ b/docs/source/recipes/Non-streaming-ASR/librispeech/index.rst
@@ -0,0 +1,12 @@
+LibriSpeech
+===========
+
+.. toctree::
+ :maxdepth: 1
+
+ tdnn_lstm_ctc
+ conformer_ctc
+ pruned_transducer_stateless
+ zipformer_mmi
+ zipformer_ctc_blankskip
+ distillation
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst
new file mode 100644
index 000000000..42fd3df77
--- /dev/null
+++ b/docs/source/recipes/Non-streaming-ASR/librispeech/pruned_transducer_stateless.rst
@@ -0,0 +1,548 @@
+.. _non_streaming_librispeech_pruned_transducer_stateless:
+
+Pruned transducer statelessX
+============================
+
+This tutorial shows you how to run a conformer transducer model
+with the `LibriSpeech `_ dataset.
+
+.. Note::
+
+ The tutorial is suitable for `pruned_transducer_stateless `_,
+ `pruned_transducer_stateless2 `_,
+ `pruned_transducer_stateless4 `_,
+   and `pruned_transducer_stateless5 `_.
+   We will take pruned_transducer_stateless4 as an example in this tutorial.
+
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have setup
+ the environment for ``icefall``.
+
+.. HINT::
+
+   We recommend using one or more GPUs to run this recipe.
+
+.. hint::
+
+ Please scroll down to the bottom of this page to find download links
+ for pretrained models if you don't want to train a model from scratch.
+
+
+We use pruned RNN-T to compute the loss.
+
+.. note::
+
+ You can find the paper about pruned RNN-T at the following address:
+
+ ``_
+
+The transducer model consists of 3 parts:
+
+  - Encoder, a.k.a. the transcription network. We use a Conformer model (the reworked version by Daniel Povey)
+  - Decoder, a.k.a. the prediction network. We use a stateless model consisting of
+    ``nn.Embedding`` and ``nn.Conv1d``
+  - Joiner, a.k.a. the joint network.
+
+.. caution::
+
+   Contrary to conventional RNN-T models, we use a stateless decoder.
+   That is, it has no recurrent connections. A toy sketch is given below.
+
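+To make this concrete, here is a minimal, hedged sketch of a stateless decoder
+built from ``nn.Embedding`` and ``nn.Conv1d``. The real decoder in the recipe
+differs in details (blank handling, padding, grouping), and the dimensions below
+are made up.
+
+.. code-block:: python
+
+   import torch
+   import torch.nn as nn
+
+   class StatelessDecoder(nn.Module):
+       """Sketch of a stateless prediction network: no recurrent state.
+
+       It embeds the last `context_size` predicted tokens and mixes them
+       with a 1-D convolution instead of an LSTM.
+       """
+
+       def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
+           super().__init__()
+           self.embedding = nn.Embedding(vocab_size, embed_dim)
+           self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)
+
+       def forward(self, y: torch.Tensor) -> torch.Tensor:
+           # y: (batch, context_size) previously emitted token ids
+           emb = self.embedding(y).permute(0, 2, 1)  # (batch, embed_dim, context_size)
+           return self.conv(emb).squeeze(-1)         # (batch, embed_dim)
+
+   decoder = StatelessDecoder(vocab_size=500, embed_dim=512)
+   tokens = torch.randint(0, 500, (4, 2))  # last 2 tokens for 4 utterances
+   print(decoder(tokens).shape)            # torch.Size([4, 512])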
+
+Data preparation
+----------------
+
+.. hint::
+
+   The data preparation is the same as for other LibriSpeech recipes.
+   If you have already finished this step, you can skip directly to ``Training``.
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+The data preparation contains several stages. You can use the following two
+options:
+
+ - ``--stage``
+ - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 2 --stop-stage 5
+
+.. HINT::
+
+ If you have pre-downloaded the `LibriSpeech `_
+ dataset and the `musan `_ dataset, say,
+ they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
+ the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+ ``./prepare.sh`` won't re-download them.
+
+.. NOTE::
+
+   All files generated by ``./prepare.sh``, e.g., features, lexicon, etc.,
+   are saved in the ``./data`` directory.
+
+We provide the following YouTube video showing how to run ``./prepare.sh``.
+
+.. note::
+
+   To get the latest news of `next-gen Kaldi `_, please subscribe to
+   the following YouTube channel by `Nadira Povey `_:
+
+ ``_
+
+.. youtube:: ofEIoJL-mGM
+
+
+Training
+--------
+
+Configurable options
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless4/train.py --help
+
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+ - ``--exp-dir``
+
+ The directory to save checkpoints, training logs and tensorboard.
+
+ - ``--full-libri``
+
+ If it's True, the training part uses all the training data, i.e.,
+ 960 hours. Otherwise, the training part uses only the subset
+ ``train-clean-100``, which has 100 hours of training data.
+
+ .. CAUTION::
+ The training set is perturbed by speed with two factors: 0.9 and 1.1.
+ If ``--full-libri`` is True, each epoch actually processes
+ ``3x960 == 2880`` hours of data.
+
+ - ``--num-epochs``
+
+ It is the number of epochs to train. For instance,
+ ``./pruned_transducer_stateless4/train.py --num-epochs 30`` trains for 30 epochs
+ and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
+ in the folder ``./pruned_transducer_stateless4/exp``.
+
+ - ``--start-epoch``
+
+ It's used to resume training.
+ ``./pruned_transducer_stateless4/train.py --start-epoch 10`` loads the
+ checkpoint ``./pruned_transducer_stateless4/exp/epoch-9.pt`` and starts
+ training from epoch 10, based on the state from epoch 9.
+
+ - ``--world-size``
+
+ It is used for multi-GPU single-machine DDP training.
+
+ - (a) If it is 1, then no DDP training is used.
+
+ - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+ The following shows some use cases with it.
+
+ **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+ GPU 2 for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="0,2"
+ $ ./pruned_transducer_stateless4/train.py --world-size 2
+
+ **Use case 2**: You have 4 GPUs and you want to use all of them
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless4/train.py --world-size 4
+
+ **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="3"
+ $ ./pruned_transducer_stateless4/train.py --world-size 1
+
+ .. caution::
+
+ Only multi-GPU single-machine DDP training is implemented at present.
+ Multi-GPU multi-machine DDP training will be added later.
+
+ - ``--max-duration``
+
+ It specifies the number of seconds over all utterances in a
+ batch, before **padding**.
+ If you encounter CUDA OOM, please reduce it.
+
+ .. HINT::
+
+ Due to padding, the number of seconds of all utterances in a
+ batch will usually be larger than ``--max-duration``.
+
+ A larger value for ``--max-duration`` may cause OOM during training,
+ while a smaller value may increase the training time. You have to
+ tune it.
+
+ - ``--use-fp16``
+
+    If it is True, the model is trained with half precision. From our experiments,
+    half precision lets you use a ``--max-duration`` about twice as large,
+    which gives an almost 2x speedup.
+
+
+Pre-configured options
+~~~~~~~~~~~~~~~~~~~~~~
+
+There are some training options, e.g., number of encoder layers,
+encoder dimension, decoder dimension, number of warmup steps etc,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`pruned_transducer_stateless4/train.py `_
+
+You don't need to change these pre-configured parameters. If you really need to change
+them, please modify ``./pruned_transducer_stateless4/train.py`` directly.
+
+
+.. NOTE::
+
+   The options for `pruned_transducer_stateless5 `_ are a little different from
+   those of other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, and ``--joiner-dim``
+   from the commandline, so that you can train models of different sizes with pruned_transducer_stateless5.
+
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in ``--exp-dir`` (e.g., ``pruned_transducer_stateless4/exp``).
+You will find the following files in that directory:
+
+ - ``epoch-1.pt``, ``epoch-2.pt``, ...
+
+ These are checkpoint files saved at the end of each epoch, containing model
+ ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+ .. code-block:: bash
+
+ $ ./pruned_transducer_stateless4/train.py --start-epoch 11
+
+ - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
+
+ These are checkpoint files saved every ``--save-every-n`` batches,
+ containing model ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
+
+ .. code-block:: bash
+
+ $ ./pruned_transducer_stateless4/train.py --start-batch 436000
+
+ - ``tensorboard/``
+
+    This folder contains TensorBoard logs. Training loss, validation loss, learning
+ rate, etc, are recorded in these logs. You can visualize them by:
+
+ .. code-block:: bash
+
+ $ cd pruned_transducer_stateless4/exp/tensorboard
+ $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall"
+
+ It will print something like below:
+
+ .. code-block::
+
+ TensorFlow installation not found - running with reduced feature set.
+ Upload started and will continue reading any new data as it's added to the logdir.
+
+ To stop uploading, press Ctrl-C.
+
+ New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/QOGSPBgsR8KzcRMmie9JGw/
+
+ [2022-11-20T15:50:50] Started scanning logdir.
+ Uploading 4468 scalars...
+ [2022-11-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects
+ Listening for new data in logdir...
+
+ Note there is a URL in the above output. Click it and you will see
+ the following screenshot:
+
+ .. figure:: images/librispeech-pruned-transducer-tensorboard-log.jpg
+ :width: 600
+ :alt: TensorBoard screenshot
+ :align: center
+ :target: https://tensorboard.dev/experiment/QOGSPBgsR8KzcRMmie9JGw/
+
+ TensorBoard screenshot.
+
+ .. hint::
+
+ If you don't have access to google, you can use the following command
+ to view the tensorboard log locally:
+
+ .. code-block:: bash
+
+ cd pruned_transducer_stateless4/exp/tensorboard
+ tensorboard --logdir . --port 6008
+
+ It will print the following message:
+
+ .. code-block::
+
+ Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
+ TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
+
+ Now start your browser and go to ``_ to view the tensorboard
+ logs.
+
+
+ - ``log/log-train-xxxx``
+
+ It is the detailed training log in text format, same as the one
+ you saw printed to the console during training.
+
+Usage example
+~~~~~~~~~~~~~
+
+You can use the following command to start the training using 6 GPUs:
+
+.. code-block:: bash
+
+ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"
+ ./pruned_transducer_stateless4/train.py \
+ --world-size 6 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --full-libri 1 \
+ --max-duration 300
+
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. hint::
+
+ There are two kinds of checkpoints:
+
+ - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
+ of each epoch. You can pass ``--epoch`` to
+ ``pruned_transducer_stateless4/decode.py`` to use them.
+
+   - (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved
+ every ``--save-every-n`` batches. You can pass ``--iter`` to
+ ``pruned_transducer_stateless4/decode.py`` to use them.
+
+ We suggest that you try both types of checkpoints and choose the one
+ that produces the lowest WERs.
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless4/decode.py --help
+
+shows the options for decoding.
+
+The following shows two examples (for two types of checkpoints):
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ for epoch in 25 20; do
+ for avg in 7 5 3 1; do
+ ./pruned_transducer_stateless4/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --max-duration 600 \
+ --decoding-method $m
+ done
+ done
+ done
+
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ for iter in 474000; do
+ for avg in 8 10 12 14 16 18; do
+ ./pruned_transducer_stateless4/decode.py \
+ --iter $iter \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --max-duration 600 \
+ --decoding-method $m
+ done
+ done
+ done
+
+
+.. Note::
+
+   The supported decoding methods are as follows:
+
+   - ``greedy_search`` : It takes the symbol with the largest posterior probability
+     at each frame as the decoding result (a toy sketch is given after this note).
+
+   - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
+     `espnet/nets/beam_search_transducer.py `_
+     is used as a reference. Basically, it keeps the top-k states for each frame and expands the kept states with their own contexts to
+     the next frame.
+
+ - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
+ runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded.
+
+   - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and
+     given ``FSAs``. It is hard to describe the details in a few lines of text; you can read
+     our paper at https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently.
+
+   - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, except that ``fast_beam_search`` uses
+     a trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph
+     (with an N-gram LM).
+
+ - ``fast_beam_search_nbest`` : It produces the decoding results as follows:
+
+ - (1) Use ``fast_beam_search`` to get a lattice
+ - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()``
+ - (3) Unique the selected paths
+ - (4) Intersect the selected paths with the lattice and compute the
+ shortest path from the intersection result
+ - (5) The path with the largest score is used as the decoding output.
+
+   - ``fast_beam_search_nbest_LG`` : It implements the same logic as ``fast_beam_search_nbest``; the
+     only difference is that it uses ``fast_beam_search_LG`` to generate the lattice.
+
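+To give a feel for the simplest of these methods, below is a hedged,
+self-contained sketch of greedy search for a stateless transducer with at most
+one symbol per frame. The encoder, decoder, and joiner are replaced by random
+stand-ins, and the real implementation in the recipe differs in many details
+(batching, caching, blank handling).
+
+.. code-block:: python
+
+   import torch
+
+   def greedy_search_sketch(encoder_out, decoder, joiner, blank_id=0, context_size=2):
+       """Greedy decoding sketch for a single utterance.
+
+       encoder_out: (num_frames, encoder_dim)
+       decoder/joiner: callables standing in for the real networks.
+       """
+       hyp = [blank_id] * context_size  # start with a blank context
+       for t in range(encoder_out.size(0)):
+           context = torch.tensor(hyp[-context_size:]).unsqueeze(0)
+           dec_out = decoder(context)                        # (1, decoder_dim)
+           logits = joiner(encoder_out[t : t + 1], dec_out)  # (1, vocab_size)
+           token = int(logits.argmax(dim=-1))
+           if token != blank_id:  # emit at most one non-blank symbol per frame
+               hyp.append(token)
+       return hyp[context_size:]  # strip the initial blank context
+
+   # Toy usage with random stand-ins for the real networks.
+   vocab_size, enc_dim, dec_dim = 500, 512, 512
+   decoder = lambda ctx: torch.randn(1, dec_dim)
+   joiner = lambda enc, dec: torch.randn(1, vocab_size)
+   print(greedy_search_sketch(torch.randn(20, enc_dim), decoder, joiner))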
+
+Export Model
+------------
+
+`pruned_transducer_stateless4/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless4/exp`` in the following ways.
+
+Export ``model.state_dict()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Checkpoints saved by ``pruned_transducer_stateless4/train.py`` also include
+``optimizer.state_dict()``. It is useful for resuming training. But after training,
+we are interested only in ``model.state_dict()``. You can use the following
+command to extract ``model.state_dict()``.
+
+.. code-block:: bash
+
+ # Assume that --epoch 25 --avg 3 produces the smallest WER
+ # (You can get such information after running ./pruned_transducer_stateless4/decode.py)
+
+ epoch=25
+ avg=3
+
+ ./pruned_transducer_stateless4/export.py \
+ --exp-dir ./pruned_transducer_stateless4/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch $epoch \
+ --avg $avg
+
+It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``.
+
+.. hint::
+
+ To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless4/decode.py``,
+ you can run:
+
+ .. code-block:: bash
+
+ cd pruned_transducer_stateless4/exp
+ ln -s pretrained.pt epoch-999.pt
+
+ And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to
+ ``./pruned_transducer_stateless4/decode.py``.
+
+To use the exported model with ``./pruned_transducer_stateless4/pretrained.py``, you
+can run:
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless4/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless4/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+
+Export model using ``torch.jit.script()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless4/export.py \
+ --exp-dir ./pruned_transducer_stateless4/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch 25 \
+ --avg 3 \
+ --jit 1
+
+It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
+load it by ``torch.jit.load("cpu_jit.pt")``.
+
+Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python
+are on CPU. You can use ``to("cuda")`` to move them to a CUDA device.
+
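+As a quick sanity check that the exported file loads, you can use a snippet like
+the one below; feature extraction and decoding are omitted here, see the recipe's
+JIT decoding script for the full pipeline.
+
+.. code-block:: python
+
+   import torch
+
+   # Load the TorchScript model exported above; its parameters are on CPU.
+   model = torch.jit.load("pruned_transducer_stateless4/exp/cpu_jit.pt")
+   model.eval()
+
+   # Optionally move the model to a GPU.
+   if torch.cuda.is_available():
+       model = model.to("cuda")
+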
+.. NOTE::
+
+ You will need this ``cpu_jit.pt`` when deploying with Sherpa framework.
+
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following links:
+
+ - `pruned_transducer_stateless `_
+
+ - `pruned_transducer_stateless2 `_
+
+ - `pruned_transducer_stateless4 `_
+
+ - `pruned_transducer_stateless5 `_
+
+ See ``_
+   for the details of the above pretrained models.
+
+
+Deploy with Sherpa
+------------------
+
+Please see ``_
+for how to deploy the models in ``sherpa``.
diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/tdnn_lstm_ctc.rst
similarity index 100%
rename from docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/librispeech/tdnn_lstm_ctc.rst
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst
new file mode 100644
index 000000000..aa73bfe33
--- /dev/null
+++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_ctc_blankskip.rst
@@ -0,0 +1,454 @@
+Zipformer CTC Blank Skip
+========================
+
+.. hint::
+
+ Please scroll down to the bottom of this page to find download links
+ for pretrained models if you don't want to train a model from scratch.
+
+
+This tutorial shows you how to train a Zipformer model based on the guidance from
+a co-trained CTC model using the `blank skip method `_
+with the `LibriSpeech `_ dataset.
+
+.. note::
+
+ We use both CTC and RNN-T loss to train. During the forward pass, the encoder output
+ is first used to calculate the CTC posterior probability; then for each output frame,
+   if its blank posterior is greater than a certain threshold, it is simply discarded
+   from the encoder output. To prevent information loss, we also put a convolution module
+   similar to the one used in conformer (referred to as “LConv”) before the frame reduction.
+   A toy sketch of the frame-skipping step is given after this note.
+
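+The following hedged sketch shows the frame-skipping idea in isolation: frames
+whose blank posterior from the CTC branch exceeds a threshold are dropped from
+the encoder output. The shapes and the threshold are made up; see the recipe
+code for the actual implementation (including the LConv module).
+
+.. code-block:: python
+
+   import torch
+
+   def skip_blank_frames(encoder_out, ctc_logits, blank_id=0, threshold=0.9):
+       """Drop encoder frames whose CTC blank posterior exceeds `threshold`.
+
+       encoder_out: (num_frames, encoder_dim)
+       ctc_logits:  (num_frames, vocab_size) logits from the CTC branch
+       """
+       blank_posterior = ctc_logits.softmax(dim=-1)[:, blank_id]  # (num_frames,)
+       keep = blank_posterior <= threshold
+       return encoder_out[keep]
+
+   # Toy usage: 100 frames, most of which are confidently blank.
+   enc = torch.randn(100, 384)
+   logits = torch.randn(100, 500)
+   logits[:, 0] += 5.0  # make blank dominant for most frames
+   print(skip_blank_frames(enc, logits).shape)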
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+.. note::
+
+ We encourage you to read ``./prepare.sh``.
+
+The data preparation contains several stages. You can use the following two
+options:
+
+ - ``--stage``
+ - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 2 --stop-stage 5
+
+.. hint::
+
+ If you have pre-downloaded the `LibriSpeech `_
+ dataset and the `musan `_ dataset, say,
+ they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
+ the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+ ``./prepare.sh`` won't re-download them.
+
+.. note::
+
+   All files generated by ``./prepare.sh``, e.g., features, lexicon, etc.,
+   are saved in the ``./data`` directory.
+
+We provide the following YouTube video showing how to run ``./prepare.sh``.
+
+.. note::
+
+   To get the latest news of `next-gen Kaldi `_, please subscribe to
+   the following YouTube channel by `Nadira Povey `_:
+
+ ``_
+
+.. youtube:: ofEIoJL-mGM
+
+Training
+--------
+
+For stability, the blank skip method is not used until the model has warmed up.
+
+Configurable options
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless7_ctc_bs/train.py --help
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+ - ``--full-libri``
+
+ If it's True, the training part uses all the training data, i.e.,
+ 960 hours. Otherwise, the training part uses only the subset
+ ``train-clean-100``, which has 100 hours of training data.
+
+ .. CAUTION::
+
+ The training set is perturbed by speed with two factors: 0.9 and 1.1.
+ If ``--full-libri`` is True, each epoch actually processes
+ ``3x960 == 2880`` hours of data.
+
+ - ``--num-epochs``
+
+ It is the number of epochs to train. For instance,
+ ``./pruned_transducer_stateless7_ctc_bs/train.py --num-epochs 30`` trains for 30 epochs
+ and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
+ in the folder ``./pruned_transducer_stateless7_ctc_bs/exp``.
+
+ - ``--start-epoch``
+
+ It's used to resume training.
+ ``./pruned_transducer_stateless7_ctc_bs/train.py --start-epoch 10`` loads the
+ checkpoint ``./pruned_transducer_stateless7_ctc_bs/exp/epoch-9.pt`` and starts
+ training from epoch 10, based on the state from epoch 9.
+
+ - ``--world-size``
+
+ It is used for multi-GPU single-machine DDP training.
+
+ - (a) If it is 1, then no DDP training is used.
+
+ - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+ The following shows some use cases with it.
+
+ **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+ GPU 2 for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="0,2"
+ $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 2
+
+ **Use case 2**: You have 4 GPUs and you want to use all of them
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 4
+
+ **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="3"
+ $ ./pruned_transducer_stateless7_ctc_bs/train.py --world-size 1
+
+ .. caution::
+
+ Only multi-GPU single-machine DDP training is implemented at present.
+ Multi-GPU multi-machine DDP training will be added later.
+
+ - ``--max-duration``
+
+ It specifies the number of seconds over all utterances in a
+ batch, before **padding**.
+ If you encounter CUDA OOM, please reduce it.
+
+ .. HINT::
+
+ Due to padding, the number of seconds of all utterances in a
+ batch will usually be larger than ``--max-duration``.
+
+ A larger value for ``--max-duration`` may cause OOM during training,
+ while a smaller value may increase the training time. You have to
+ tune it.
+
+
+Pre-configured options
+~~~~~~~~~~~~~~~~~~~~~~
+
+There are some training options, e.g., weight decay,
+number of warmup steps, results dir, etc,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`pruned_transducer_stateless7_ctc_bs/train.py `_
+
+You don't need to change these pre-configured parameters. If you really need to change
+them, please modify ``./pruned_transducer_stateless7_ctc_bs/train.py`` directly.
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in ``pruned_transducer_stateless7_ctc_bs/exp``.
+You will find the following files in that directory:
+
+ - ``epoch-1.pt``, ``epoch-2.pt``, ...
+
+ These are checkpoint files saved at the end of each epoch, containing model
+ ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+ .. code-block:: bash
+
+ $ ./pruned_transducer_stateless7_ctc_bs/train.py --start-epoch 11
+
+ - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
+
+ These are checkpoint files saved every ``--save-every-n`` batches,
+ containing model ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
+
+ .. code-block:: bash
+
+ $ ./pruned_transducer_stateless7_ctc_bs/train.py --start-batch 436000
+
+ - ``tensorboard/``
+
+    This folder contains TensorBoard logs. Training loss, validation loss, learning
+ rate, etc, are recorded in these logs. You can visualize them by:
+
+ .. code-block:: bash
+
+ $ cd pruned_transducer_stateless7_ctc_bs/exp/tensorboard
+ $ tensorboard dev upload --logdir . --description "Zipformer-CTC co-training using blank skip for LibriSpeech with icefall"
+
+ It will print something like below:
+
+ .. code-block::
+
+ TensorFlow installation not found - running with reduced feature set.
+ Upload started and will continue reading any new data as it's added to the logdir.
+
+ To stop uploading, press Ctrl-C.
+
+ New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/xyOZUKpEQm62HBIlUD4uPA/
+
+ Note there is a URL in the above output. Click it and you will see
+ tensorboard.
+
+ .. hint::
+
+ If you don't have access to google, you can use the following command
+ to view the tensorboard log locally:
+
+ .. code-block:: bash
+
+ cd pruned_transducer_stateless7_ctc_bs/exp/tensorboard
+ tensorboard --logdir . --port 6008
+
+ It will print the following message:
+
+ .. code-block::
+
+ Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
+ TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
+
+ Now start your browser and go to ``_ to view the tensorboard
+ logs.
+
+
+ - ``log/log-train-xxxx``
+
+ It is the detailed training log in text format, same as the one
+ you saw printed to the console during training.
+
+Usage example
+~~~~~~~~~~~~~
+
+You can use the following command to start the training using 4 GPUs:
+
+.. code-block:: bash
+
+ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ ./pruned_transducer_stateless7_ctc_bs/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --full-libri 1 \
+ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \
+ --max-duration 600 \
+ --use-fp16 1
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. hint::
+
+ There are two kinds of checkpoints:
+
+ - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
+ of each epoch. You can pass ``--epoch`` to
+ ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them.
+
+   - (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved
+ every ``--save-every-n`` batches. You can pass ``--iter`` to
+ ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them.
+
+ We suggest that you try both types of checkpoints and choose the one
+ that produces the lowest WERs.
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py --help
+
+shows the options for decoding.
+
+The following shows the example using ``epoch-*.pt``:
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
+ --epoch 30 \
+ --avg 13 \
+ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \
+ --max-duration 600 \
+ --decoding-method $m
+ done
+
+To test the CTC branch, you can use the following command:
+
+.. code-block:: bash
+
+ for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
+ --epoch 30 \
+ --avg 13 \
+ --exp-dir pruned_transducer_stateless7_ctc_bs/exp \
+ --max-duration 600 \
+ --decoding-method $m
+ done
+
+Export models
+-------------
+
+`pruned_transducer_stateless7_ctc_bs/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless7_ctc_bs/exp`` in the following ways.
+
+Export ``model.state_dict()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Checkpoints saved by ``pruned_transducer_stateless7_ctc_bs/train.py`` also include
+``optimizer.state_dict()``. It is useful for resuming training. But after training,
+we are interested only in ``model.state_dict()``. You can use the following
+command to extract ``model.state_dict()``.
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless7_ctc_bs/export.py \
+ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 13 \
+ --jit 0
+
+It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt``.
+
+.. hint::
+
+ To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``,
+ you can run:
+
+ .. code-block:: bash
+
+ cd pruned_transducer_stateless7_ctc_bs/exp
+      ln -s pretrained.pt epoch-9999.pt
+
+ And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to
+ ``./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``.
+
+To use the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained.py``, you
+can run:
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+To test the CTC branch using the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py``:
+
+.. code-block:: bash
+
+   ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
+ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --method ctc-decoding \
+ --sample-rate 16000 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+Export model using ``torch.jit.script()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless7_ctc_bs/export.py \
+ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 13 \
+ --jit 1
+
+It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
+load it by ``torch.jit.load("cpu_jit.pt")``.
+
+Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python
+are on CPU. You can use ``to("cuda")`` to move them to a CUDA device.
+
+To use the generated files with ``./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py``:
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \
+ --nn-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+To test the CTC branch using the generated files with ``./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py``:
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
+ --model-filename ./pruned_transducer_stateless7_ctc_bs/exp/cpu_jit.pt \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --method ctc-decoding \
+ --sample-rate 16000 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following links:
+
+ - trained on LibriSpeech 100h: ``_
+ - trained on LibriSpeech 960h: ``_
+
+ See ``_
+   for the details of the above pretrained models.
diff --git a/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst
new file mode 100644
index 000000000..a7b59a992
--- /dev/null
+++ b/docs/source/recipes/Non-streaming-ASR/librispeech/zipformer_mmi.rst
@@ -0,0 +1,422 @@
+Zipformer MMI
+===============
+
+.. hint::
+
+ Please scroll down to the bottom of this page to find download links
+ for pretrained models if you don't want to train a model from scratch.
+
+
+This tutorial shows you how to train a Zipformer MMI model
+with the `LibriSpeech `_ dataset.
+
+We use LF-MMI to compute the loss.
+
+.. note::
+
+ You can find the document about LF-MMI training at the following address:
+
+ ``_
+
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+.. note::
+
+ We encourage you to read ``./prepare.sh``.
+
+The data preparation contains several stages. You can use the following two
+options:
+
+ - ``--stage``
+ - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 2 --stop-stage 5
+
+.. hint::
+
+ If you have pre-downloaded the `LibriSpeech `_
+ dataset and the `musan `_ dataset, say,
+ they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
+ the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+ ``./prepare.sh`` won't re-download them.
+
+.. note::
+
+   All files generated by ``./prepare.sh``, e.g., features, lexicon, etc.,
+   are saved in the ``./data`` directory.
+
+We provide the following YouTube video showing how to run ``./prepare.sh``.
+
+.. note::
+
+   To get the latest news of `next-gen Kaldi `_, please subscribe to
+   the following YouTube channel by `Nadira Povey `_:
+
+ ``_
+
+.. youtube:: ofEIoJL-mGM
+
+Training
+--------
+
+For stability, it uses CTC loss for model warm-up and then switches to MMI loss.
+
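+Conceptually, the switch can be pictured with the hedged snippet below; the names
+and the warm-up step count are made up, see ``get_params()`` in
+``zipformer_mmi/train.py`` for the real criterion.
+
+.. code-block:: python
+
+   import torch
+
+   warmup_batches = 2000  # made-up value for illustration
+
+   def training_loss(batch_idx: int, ctc_loss: torch.Tensor, mmi_loss: torch.Tensor):
+       # Use the CTC loss during warm-up, then switch to the LF-MMI loss.
+       return ctc_loss if batch_idx < warmup_batches else mmi_loss
+
+   print(training_loss(100, torch.tensor(35.0), torch.tensor(60.0)))   # tensor(35.)
+   print(training_loss(5000, torch.tensor(35.0), torch.tensor(60.0)))  # tensor(60.)
+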
+Configurable options
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./zipformer_mmi/train.py --help
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+ - ``--full-libri``
+
+ If it's True, the training part uses all the training data, i.e.,
+ 960 hours. Otherwise, the training part uses only the subset
+ ``train-clean-100``, which has 100 hours of training data.
+
+ .. CAUTION::
+
+ The training set is perturbed by speed with two factors: 0.9 and 1.1.
+ If ``--full-libri`` is True, each epoch actually processes
+ ``3x960 == 2880`` hours of data.
+
+ - ``--num-epochs``
+
+ It is the number of epochs to train. For instance,
+ ``./zipformer_mmi/train.py --num-epochs 30`` trains for 30 epochs
+ and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
+ in the folder ``./zipformer_mmi/exp``.
+
+ - ``--start-epoch``
+
+ It's used to resume training.
+ ``./zipformer_mmi/train.py --start-epoch 10`` loads the
+ checkpoint ``./zipformer_mmi/exp/epoch-9.pt`` and starts
+ training from epoch 10, based on the state from epoch 9.
+
+ - ``--world-size``
+
+ It is used for multi-GPU single-machine DDP training.
+
+ - (a) If it is 1, then no DDP training is used.
+
+ - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+ The following shows some use cases with it.
+
+ **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+ GPU 2 for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="0,2"
+ $ ./zipformer_mmi/train.py --world-size 2
+
+ **Use case 2**: You have 4 GPUs and you want to use all of them
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./zipformer_mmi/train.py --world-size 4
+
+ **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="3"
+ $ ./zipformer_mmi/train.py --world-size 1
+
+ .. caution::
+
+ Only multi-GPU single-machine DDP training is implemented at present.
+ Multi-GPU multi-machine DDP training will be added later.
+
+ - ``--max-duration``
+
+ It specifies the number of seconds over all utterances in a
+ batch, before **padding**.
+ If you encounter CUDA OOM, please reduce it.
+
+ .. HINT::
+
+ Due to padding, the number of seconds of all utterances in a
+ batch will usually be larger than ``--max-duration``.
+
+ A larger value for ``--max-duration`` may cause OOM during training,
+ while a smaller value may increase the training time. You have to
+ tune it.
+
+
+Pre-configured options
+~~~~~~~~~~~~~~~~~~~~~~
+
+There are some training options, e.g., weight decay,
+number of warmup steps, results dir, etc,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`zipformer_mmi/train.py `_
+
+You don't need to change these pre-configured parameters. If you really need to change
+them, please modify ``./zipformer_mmi/train.py`` directly.
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in ``zipformer_mmi/exp``.
+You will find the following files in that directory:
+
+ - ``epoch-1.pt``, ``epoch-2.pt``, ...
+
+ These are checkpoint files saved at the end of each epoch, containing model
+ ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+ .. code-block:: bash
+
+ $ ./zipformer_mmi/train.py --start-epoch 11
+
+ - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
+
+ These are checkpoint files saved every ``--save-every-n`` batches,
+ containing model ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
+
+ .. code-block:: bash
+
+ $ ./zipformer_mmi/train.py --start-batch 436000
+
+ - ``tensorboard/``
+
+    This folder contains TensorBoard logs. Training loss, validation loss, learning
+ rate, etc, are recorded in these logs. You can visualize them by:
+
+ .. code-block:: bash
+
+ $ cd zipformer_mmi/exp/tensorboard
+ $ tensorboard dev upload --logdir . --description "Zipformer MMI training for LibriSpeech with icefall"
+
+ It will print something like below:
+
+ .. code-block::
+
+ TensorFlow installation not found - running with reduced feature set.
+ Upload started and will continue reading any new data as it's added to the logdir.
+
+ To stop uploading, press Ctrl-C.
+
+ New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/xyOZUKpEQm62HBIlUD4uPA/
+
+ Note there is a URL in the above output. Click it and you will see
+ tensorboard.
+
+ .. hint::
+
+ If you don't have access to google, you can use the following command
+ to view the tensorboard log locally:
+
+ .. code-block:: bash
+
+ cd zipformer_mmi/exp/tensorboard
+ tensorboard --logdir . --port 6008
+
+ It will print the following message:
+
+ .. code-block::
+
+ Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
+ TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
+
+ Now start your browser and go to ``_ to view the tensorboard
+ logs.
+
+
+ - ``log/log-train-xxxx``
+
+ It is the detailed training log in text format, same as the one
+ you saw printed to the console during training.
+
+Usage example
+~~~~~~~~~~~~~
+
+You can use the following command to start the training using 4 GPUs:
+
+.. code-block:: bash
+
+ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ ./zipformer_mmi/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --full-libri 1 \
+ --exp-dir zipformer_mmi/exp \
+ --max-duration 500 \
+ --use-fp16 1 \
+ --num-workers 2
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. hint::
+
+ There are two kinds of checkpoints:
+
+ - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
+ of each epoch. You can pass ``--epoch`` to
+ ``zipformer_mmi/decode.py`` to use them.
+
+   - (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved
+ every ``--save-every-n`` batches. You can pass ``--iter`` to
+ ``zipformer_mmi/decode.py`` to use them.
+
+ We suggest that you try both types of checkpoints and choose the one
+ that produces the lowest WERs.
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./zipformer_mmi/decode.py --help
+
+shows the options for decoding.
+
+The following shows the example using ``epoch-*.pt``:
+
+.. code-block:: bash
+
+ for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+ ./zipformer_mmi/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir ./zipformer_mmi/exp/ \
+ --max-duration 100 \
+ --lang-dir data/lang_bpe_500 \
+ --nbest-scale 1.2 \
+ --hp-scale 1.0 \
+ --decoding-method $m
+ done
+
+
+Export models
+-------------
+
+`zipformer_mmi/export.py `_ supports exporting checkpoints from ``zipformer_mmi/exp`` in the following ways.
+
+Export ``model.state_dict()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Checkpoints saved by ``zipformer_mmi/train.py`` also include
+``optimizer.state_dict()``. It is useful for resuming training. But after training,
+we are interested only in ``model.state_dict()``. You can use the following
+command to extract ``model.state_dict()``.
+
+.. code-block:: bash
+
+ ./zipformer_mmi/export.py \
+ --exp-dir ./zipformer_mmi/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 9 \
+ --jit 0
+
+It will generate a file ``./zipformer_mmi/exp/pretrained.pt``.
+
+.. hint::
+
+ To use the generated ``pretrained.pt`` for ``zipformer_mmi/decode.py``,
+ you can run:
+
+ .. code-block:: bash
+
+ cd zipformer_mmi/exp
+      ln -s pretrained.pt epoch-9999.pt
+
+ And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to
+ ``./zipformer_mmi/decode.py``.
+
+To use the exported model with ``./zipformer_mmi/pretrained.py``, you
+can run:
+
+.. code-block:: bash
+
+ ./zipformer_mmi/pretrained.py \
+ --checkpoint ./zipformer_mmi/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method 1best \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+Export model using ``torch.jit.script()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ ./zipformer_mmi/export.py \
+ --exp-dir ./zipformer_mmi/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
+load it by ``torch.jit.load("cpu_jit.pt")``.
+
+Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python
+are on CPU. You can use ``to("cuda")`` to move them to a CUDA device.
+
+To use the generated files with ``./zipformer_mmi/jit_pretrained.py``:
+
+.. code-block:: bash
+
+ ./zipformer_mmi/jit_pretrained.py \
+ --nn-model-filename ./zipformer_mmi/exp/cpu_jit.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method 1best \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following links:
+
+ - ``_
+
+ See ``_
+   for the details of the above pretrained models.
diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/Non-streaming-ASR/timit/index.rst
similarity index 98%
rename from docs/source/recipes/timit/index.rst
rename to docs/source/recipes/Non-streaming-ASR/timit/index.rst
index 17f40cdb7..5ee147be7 100644
--- a/docs/source/recipes/timit/index.rst
+++ b/docs/source/recipes/Non-streaming-ASR/timit/index.rst
@@ -6,4 +6,3 @@ TIMIT
tdnn_ligru_ctc
tdnn_lstm_ctc
-
diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst
similarity index 97%
rename from docs/source/recipes/timit/tdnn_ligru_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst
index 186420ee7..3d7aefe02 100644
--- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst
+++ b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_ligru_ctc.rst
@@ -148,10 +148,10 @@ Some commonly used options are:
$ ./tdnn_ligru_ctc/decode.py --epoch 25 --avg 17
- uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``,
- ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``,
- ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``,
- ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``,
+ uses the average of ``epoch-9.pt``, ``epoch-10.pt``, ``epoch-11.pt``,
+ ``epoch-12.pt``, ``epoch-13.pt``, ``epoch-14.pt``, ``epoch-15.pt``,
+ ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``, ``epoch-19.pt``,
+ ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``, ``epoch-23.pt``,
``epoch-24.pt`` and ``epoch-25.pt``
for decoding.
@@ -317,13 +317,13 @@ To decode with ``1best`` method, we can use:
.. code-block:: bash
- ./tdnn_ligru_ctc/pretrained.py
+ ./tdnn_ligru_ctc/pretrained.py
--method 1best
- --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt
- --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt
- --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
- ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
- ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
+ --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt
+ --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt
+ --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
+ ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
+ ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
The output is:
@@ -337,7 +337,7 @@ The output is:
2021-11-08 20:41:38,697 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 20:41:38,704 INFO [pretrained.py:216] Decoding started
2021-11-08 20:41:39,819 INFO [pretrained.py:246] Use HLG decoding
- 2021-11-08 20:41:39,829 INFO [pretrained.py:267]
+ 2021-11-08 20:41:39,829 INFO [pretrained.py:267]
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx ih w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
@@ -362,8 +362,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use
--HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt \
--G ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.1 \
- ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
- ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
+ ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV
+ ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV
The decoding output is:
@@ -378,7 +378,7 @@ The decoding output is:
2021-11-08 20:37:54,715 INFO [pretrained.py:210] Reading sound files: ['./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FELC0_SI756.WAV', './tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 20:37:54,720 INFO [pretrained.py:216] Decoding started
2021-11-08 20:37:55,808 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
- 2021-11-08 20:37:56,348 INFO [pretrained.py:267]
+ 2021-11-08 20:37:56,348 INFO [pretrained.py:267]
./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw ah l iy v iy z ah sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy ng w ih th ih n ih m s eh l f sil jh
diff --git a/docs/source/recipes/timit/tdnn_lstm_ctc.rst b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst
similarity index 97%
rename from docs/source/recipes/timit/tdnn_lstm_ctc.rst
rename to docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst
index 6f760a9ce..ee67a6edc 100644
--- a/docs/source/recipes/timit/tdnn_lstm_ctc.rst
+++ b/docs/source/recipes/Non-streaming-ASR/timit/tdnn_lstm_ctc.rst
@@ -148,8 +148,8 @@ Some commonly used options are:
$ ./tdnn_lstm_ctc/decode.py --epoch 25 --avg 10
- uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``,
- ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``,
+ uses the average of ``epoch-16.pt``, ``epoch-17.pt``, ``epoch-18.pt``,
+ ``epoch-19.pt``, ``epoch-20.pt``, ``epoch-21.pt``, ``epoch-22.pt``,
``epoch-23.pt``, ``epoch-24.pt`` and ``epoch-25.pt``
for decoding.
@@ -315,13 +315,13 @@ To decode with ``1best`` method, we can use:
.. code-block:: bash
- ./tdnn_lstm_ctc/pretrained.py
+ ./tdnn_lstm_ctc/pretrained.py
--method 1best
- --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
- --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
- --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
- ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
- ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
+ --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
+ --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
+ --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
+ ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
+ ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
The output is:
@@ -335,7 +335,7 @@ The output is:
2021-11-08 21:02:53,827 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 21:02:53,831 INFO [pretrained.py:216] Decoding started
2021-11-08 21:02:54,380 INFO [pretrained.py:246] Use HLG decoding
- 2021-11-08 21:02:54,387 INFO [pretrained.py:267]
+ 2021-11-08 21:02:54,387 INFO [pretrained.py:267]
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw ah l iy v iy z ih sil p r aa sil k s ih m ey dx ih sil d w uh dx iy w ih s f iy l iy w ih th ih n ih m s eh l f sil jh
@@ -360,8 +360,8 @@ To decode with ``whole-lattice-rescoring`` methond, you can use
--HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
--G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.08 \
- ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
- ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
+ ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
+ ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
The decoding output is:
@@ -376,7 +376,7 @@ The decoding output is:
2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started
2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
- 2021-11-08 20:05:27,878 INFO [pretrained.py:267]
+ 2021-11-08 20:05:27,878 INFO [pretrained.py:267]
./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh
diff --git a/docs/source/recipes/yesno/images/tdnn-tensorboard-log.png b/docs/source/recipes/Non-streaming-ASR/yesno/images/tdnn-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/yesno/images/tdnn-tensorboard-log.png
rename to docs/source/recipes/Non-streaming-ASR/yesno/images/tdnn-tensorboard-log.png
diff --git a/docs/source/recipes/yesno/index.rst b/docs/source/recipes/Non-streaming-ASR/yesno/index.rst
similarity index 100%
rename from docs/source/recipes/yesno/index.rst
rename to docs/source/recipes/Non-streaming-ASR/yesno/index.rst
diff --git a/docs/source/recipes/yesno/tdnn.rst b/docs/source/recipes/Non-streaming-ASR/yesno/tdnn.rst
similarity index 100%
rename from docs/source/recipes/yesno/tdnn.rst
rename to docs/source/recipes/Non-streaming-ASR/yesno/tdnn.rst
diff --git a/docs/source/recipes/Streaming-ASR/index.rst b/docs/source/recipes/Streaming-ASR/index.rst
new file mode 100644
index 000000000..8c0ffe447
--- /dev/null
+++ b/docs/source/recipes/Streaming-ASR/index.rst
@@ -0,0 +1,12 @@
+Streaming ASR
+=============
+
+.. toctree::
+ :maxdepth: 1
+
+ introduction
+
+.. toctree::
+ :maxdepth: 2
+
+ librispeech/index
diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst
new file mode 100644
index 000000000..e1382e77d
--- /dev/null
+++ b/docs/source/recipes/Streaming-ASR/introduction.rst
@@ -0,0 +1,53 @@
+Introduction
+============
+
+This page shows you how we implement streaming **X-former transducer** models for ASR.
+
+.. HINT::
+ X-former transducer here means that the encoder of the transducer model uses Multi-Head Attention,
+ like `Conformer `_, `Emformer `_, etc.
+
+Currently, we have implemented two types of streaming models: one uses Conformer as the encoder, the other uses Emformer as the encoder.
+
+Streaming Conformer
+-------------------
+
+The main idea of training a streaming model is to make the model see only limited context
+at training time; we can achieve this by applying a mask to the output of self-attention.
+In icefall, we implement the streaming conformer in the same way as `WeNet `_ does.
+
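+The snippet below is a minimal sketch of such a chunk-wise attention mask (for illustration
+only; the function name is hypothetical and the actual implementation in icefall lives in the
+recipe code, which also handles dynamic chunk sizes and caching):
+
+.. code-block:: python
+
+ import torch
+
+ def chunk_attention_mask(seq_len: int, chunk_size: int, num_left_chunks: int) -> torch.Tensor:
+     """Return a (seq_len, seq_len) boolean mask; mask[i, j] is True if
+     frame i is allowed to attend to frame j."""
+     chunk_idx = torch.arange(seq_len) // chunk_size          # chunk index of each frame
+     diff = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0)   # chunk(i) - chunk(j)
+     # a frame sees its own chunk and at most `num_left_chunks` chunks to the left
+     return (diff >= 0) & (diff <= num_left_chunks)
+
+ mask = chunk_attention_mask(seq_len=12, chunk_size=4, num_left_chunks=1)
+ # positions where the mask is False are set to -inf before the softmax in self-attention
+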
+.. NOTE::
+ The conformer-transducer recipes for the LibriSpeech dataset, such as `pruned_transducer_stateless `_,
+ `pruned_transducer_stateless2 `_,
+ `pruned_transducer_stateless3 `_,
+ `pruned_transducer_stateless4 `_,
+ `pruned_transducer_stateless5 `_
+ all support streaming.
+
+.. NOTE::
+ Training a streaming conformer model in ``icefall`` is almost the same as training a
+ non-streaming model; all you need to do is pass several extra arguments.
+ See :doc:`Pruned transducer statelessX ` for more details.
+
+.. HINT::
+ If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming, please refer
+ to `this pull request `_. After adding the code needed for streaming training,
+ you have to re-train the model with the extra arguments mentioned in the docs above to get a streaming model.
+
+
+Streaming Emformer
+------------------
+
+The Emformer model proposed `here `_ uses more
+complicated techniques. It has a memory bank component to memorize history information;
+what's more, it also introduces right context at training time by hard-copying part of
+the input features.
+
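+To illustrate the "hard copy of right context" idea, here is a deliberately simplified sketch
+(the function is hypothetical and ignores the memory bank and attention caches that the real
+Emformer also maintains):
+
+.. code-block:: python
+
+ import torch
+ from typing import List
+
+ def split_with_right_context(x: torch.Tensor, chunk_size: int, right_context: int) -> List[torch.Tensor]:
+     """x: (T, C) features. Split x into chunks of `chunk_size` frames and
+     append up to `right_context` future frames (a hard copy) to each chunk."""
+     chunks = []
+     T = x.size(0)
+     for start in range(0, T, chunk_size):
+         end = min(start + chunk_size, T)
+         chunks.append(x[start:min(end + right_context, T)])
+     return chunks
+
+ x = torch.randn(100, 80)  # 100 frames of 80-dim fbank features
+ chunks = split_with_right_context(x, chunk_size=32, right_context=8)
+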
+We have three variants of Emformer models in ``icefall``.
+
+ - ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe `_.
+ - ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourselves. Different from the Emformer in torchaudio,
+ ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model.
+ See `LibriSpeech recipe `_.
+ - ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourselves. The only difference from the above one is that
+ it uses a simplified memory bank. See `LibriSpeech recipe `_.
diff --git a/docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png b/docs/source/recipes/Streaming-ASR/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png
rename to docs/source/recipes/Streaming-ASR/librispeech/images/librispeech-lstm-transducer-tensorboard-log.png
diff --git a/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg b/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg
new file mode 100644
index 000000000..9c77b8bae
Binary files /dev/null and b/docs/source/recipes/Streaming-ASR/librispeech/images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg differ
diff --git a/docs/source/recipes/librispeech/index.rst b/docs/source/recipes/Streaming-ASR/librispeech/index.rst
similarity index 61%
rename from docs/source/recipes/librispeech/index.rst
rename to docs/source/recipes/Streaming-ASR/librispeech/index.rst
index 6c91b6750..d52e08058 100644
--- a/docs/source/recipes/librispeech/index.rst
+++ b/docs/source/recipes/Streaming-ASR/librispeech/index.rst
@@ -4,6 +4,8 @@ LibriSpeech
.. toctree::
:maxdepth: 1
- tdnn_lstm_ctc
- conformer_ctc
+ pruned_transducer_stateless
+
lstm_pruned_stateless_transducer
+
+ zipformer_transducer
diff --git a/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst
similarity index 81%
rename from docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst
rename to docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst
index 643855cc2..911e84656 100644
--- a/docs/source/recipes/librispeech/lstm_pruned_stateless_transducer.rst
+++ b/docs/source/recipes/Streaming-ASR/librispeech/lstm_pruned_stateless_transducer.rst
@@ -515,110 +515,6 @@ To use the generated files with ``./lstm_transducer_stateless2/jit_pretrained``:
Please see ``_
for how to use the exported models in ``sherpa``.
-.. _export-model-for-ncnn:
-
-Export model for ncnn
-~~~~~~~~~~~~~~~~~~~~~
-
-We support exporting pretrained LSTM transducer models to
-`ncnn `_ using
-`pnnx `_.
-
-First, let us install a modified version of ``ncnn``:
-
-.. code-block:: bash
-
- git clone https://github.com/csukuangfj/ncnn
- cd ncnn
- git submodule update --recursive --init
- python3 setup.py bdist_wheel
- ls -lh dist/
- pip install ./dist/*.whl
-
- # now build pnnx
- cd tools/pnnx
- mkdir build
- cd build
- make -j4
- export PATH=$PWD/src:$PATH
-
- ./src/pnnx
-
-.. note::
-
- We assume that you have added the path to the binary ``pnnx`` to the
- environment variable ``PATH``.
-
-Second, let us export the model using ``torch.jit.trace()`` that is suitable
-for ``pnnx``:
-
-.. code-block:: bash
-
- iter=468000
- avg=16
-
- ./lstm_transducer_stateless2/export.py \
- --exp-dir ./lstm_transducer_stateless2/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
- --iter $iter \
- --avg $avg \
- --pnnx 1
-
-It will generate 3 files:
-
- - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt``
- - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt``
- - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt``
-
-Third, convert torchscript model to ``ncnn`` format:
-
-.. code-block::
-
- pnnx ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt
- pnnx ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt
- pnnx ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt
-
-It will generate the following files:
-
- - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param``
- - ``./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin``
- - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param``
- - ``./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin``
- - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param``
- - ``./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin``
-
-To use the above generated files, run:
-
-.. code-block:: bash
-
- ./lstm_transducer_stateless2/ncnn-decode.py \
- --bpe-model-filename ./data/lang_bpe_500/bpe.model \
- --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
- --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
- --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
- --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \
- --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \
- --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \
- /path/to/foo.wav
-
-.. code-block:: bash
-
- ./lstm_transducer_stateless2/streaming-ncnn-decode.py \
- --bpe-model-filename ./data/lang_bpe_500/bpe.model \
- --encoder-param-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.param \
- --encoder-bin-filename ./lstm_transducer_stateless2/exp/encoder_jit_trace-pnnx.ncnn.bin \
- --decoder-param-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.param \
- --decoder-bin-filename ./lstm_transducer_stateless2/exp/decoder_jit_trace-pnnx.ncnn.bin \
- --joiner-param-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.param \
- --joiner-bin-filename ./lstm_transducer_stateless2/exp/joiner_jit_trace-pnnx.ncnn.bin \
- /path/to/foo.wav
-
-To use the above generated files in C++, please see
-``_
-
-It is able to generate a static linked executable that can be run on Linux, Windows,
-macOS, Raspberry Pi, etc, without external dependencies.
-
Download pretrained models
--------------------------
diff --git a/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst
new file mode 100644
index 000000000..de7102ba8
--- /dev/null
+++ b/docs/source/recipes/Streaming-ASR/librispeech/pruned_transducer_stateless.rst
@@ -0,0 +1,735 @@
+Pruned transducer statelessX
+============================
+
+This tutorial shows you how to run a **streaming** conformer transducer model
+with the `LibriSpeech `_ dataset.
+
+.. Note::
+
+ The tutorial is suitable for `pruned_transducer_stateless `_,
+ `pruned_transducer_stateless2 `_,
+ `pruned_transducer_stateless4 `_,
+ `pruned_transducer_stateless5 `_.
+ We will take ``pruned_transducer_stateless4`` as an example in this tutorial.
+
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have setup
+ the environment for ``icefall``.
+
+.. HINT::
+
+ We recommend that you use a GPU or several GPUs to run this recipe.
+
+.. hint::
+
+ Please scroll down to the bottom of this page to find download links
+ for pretrained models if you don't want to train a model from scratch.
+
+
+We use pruned RNN-T to compute the loss.
+
+.. note::
+
+ You can find the paper about pruned RNN-T at the following address:
+
+ ``_
+
+The transducer model consists of 3 parts:
+
+ - Encoder, a.k.a, the transcription network. We use a Conformer model (the reworked version by Daniel Povey)
+ - Decoder, a.k.a, the prediction network. We use a stateless model consisting of
+ ``nn.Embedding`` and ``nn.Conv1d``
+ - Joiner, a.k.a, the joint network.
+
+.. caution::
+
+ Contrary to the conventional RNN-T models, we use a stateless decoder.
+ That is, it has no recurrent connections.
+
+
+Data preparation
+----------------
+
+.. hint::
+
+ The data preparation is the same as for other LibriSpeech recipes.
+ If you have already finished this step, you can skip to ``Training`` directly.
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+The data preparation contains several stages; you can use the following two
+options:
+
+ - ``--stage``
+ - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 2 --stop-stage 5
+
+.. HINT::
+
+ If you have pre-downloaded the `LibriSpeech `_
+ dataset and the `musan `_ dataset, say,
+ they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
+ the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+ ``./prepare.sh`` won't re-download them.
+
+.. NOTE::
+
+ All generated files by ``./prepare.sh``, e.g., features, lexicon, etc,
+ are saved in ``./data`` directory.
+
+We provide the following YouTube video showing how to run ``./prepare.sh``.
+
+.. note::
+
+ To get the latest news of `next-gen Kaldi `_, please subscribe to
+ the following YouTube channel by `Nadira Povey `_:
+
+ ``_
+
+.. youtube:: ofEIoJL-mGM
+
+
+Training
+--------
+
+.. NOTE::
+
+ We put the streaming and non-streaming models in one recipe; to train a streaming model, you only
+ need to add **4** extra options compared with training a non-streaming model. These options are
+ ``--dynamic-chunk-training``, ``--num-left-chunks``, ``--causal-convolution``, ``--short-chunk-size``.
+ You can see the configurable options below for their meanings, or read https://arxiv.org/pdf/2012.05481.pdf for more details.
+
+Configurable options
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless4/train.py --help
+
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+ - ``--exp-dir``
+
+ The directory to save checkpoints, training logs and tensorboard.
+
+ - ``--full-libri``
+
+ If it's True, the training part uses all the training data, i.e.,
+ 960 hours. Otherwise, the training part uses only the subset
+ ``train-clean-100``, which has 100 hours of training data.
+
+ .. CAUTION::
+ The training set is perturbed by speed with two factors: 0.9 and 1.1.
+ If ``--full-libri`` is True, each epoch actually processes
+ ``3x960 == 2880`` hours of data.
+
+ - ``--num-epochs``
+
+ It is the number of epochs to train. For instance,
+ ``./pruned_transducer_stateless4/train.py --num-epochs 30`` trains for 30 epochs
+ and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
+ in the folder ``./pruned_transducer_stateless4/exp``.
+
+ - ``--start-epoch``
+
+ It's used to resume training.
+ ``./pruned_transducer_stateless4/train.py --start-epoch 10`` loads the
+ checkpoint ``./pruned_transducer_stateless4/exp/epoch-9.pt`` and starts
+ training from epoch 10, based on the state from epoch 9.
+
+ - ``--world-size``
+
+ It is used for multi-GPU single-machine DDP training.
+
+ - (a) If it is 1, then no DDP training is used.
+
+ - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+ The following shows some use cases with it.
+
+ **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+ GPU 2 for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="0,2"
+ $ ./pruned_transducer_stateless4/train.py --world-size 2
+
+ **Use case 2**: You have 4 GPUs and you want to use all of them
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless4/train.py --world-size 4
+
+ **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="3"
+ $ ./pruned_transducer_stateless4/train.py --world-size 1
+
+ .. caution::
+
+ Only multi-GPU single-machine DDP training is implemented at present.
+ Multi-GPU multi-machine DDP training will be added later.
+
+ - ``--max-duration``
+
+ It specifies the number of seconds over all utterances in a
+ batch, before **padding**.
+ If you encounter CUDA OOM, please reduce it.
+
+ .. HINT::
+
+ Due to padding, the number of seconds of all utterances in a
+ batch will usually be larger than ``--max-duration``.
+
+ A larger value for ``--max-duration`` may cause OOM during training,
+ while a smaller value may increase the training time. You have to
+ tune it.
+
+ - ``--use-fp16``
+
+ If it is True, the model will be trained with half precision. From our experiments,
+ using half precision allows a ``--max-duration`` about twice as large,
+ which gives an almost 2x speedup.
+
+ - ``--dynamic-chunk-training``
+
+ The flag that indicates whether to train a streaming model or not; it
+ **MUST** be True if you want to train a streaming model.
+
+ - ``--short-chunk-size``
+
+ When training a streaming attention model with chunk masking, the chunk size is either
+ the maximum sequence length of the current batch or uniformly sampled from
+ (1, short_chunk_size); see the sketch after this list of options. The default value is 25; you don't have to change it most of the time.
+
+ - ``--num-left-chunks``
+
+ It indicates how much left context (in chunks) can be seen when calculating attention.
+ The default value is 4; you don't have to change it most of the time.
+
+
+ - ``--causal-convolution``
+
+ Whether to use causal convolution in the conformer encoder layers; this must
+ be True when training a streaming model.
+
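+The following is a simplified sketch of how the training chunk size could be sampled per batch
+(``full_context_prob`` is a hypothetical knob for this sketch; the exact sampling scheme used by
+the recipe differs slightly, but the idea is the same):
+
+.. code-block:: python
+
+ import random
+
+ def sample_training_chunk_size(max_len: int, short_chunk_size: int = 25,
+                                full_context_prob: float = 0.5) -> int:
+     """Pick the chunk size used for one batch during dynamic chunk training."""
+     if random.random() < full_context_prob:
+         return max_len                           # full context: behaves like a non-streaming model
+     return random.randint(1, short_chunk_size)   # limited context for streaming behaviour
+
+ # each batch gets its own chunk size; an attention mask is then built from it,
+ # together with --num-left-chunks, as described in the streaming introduction
+ chunk_size = sample_training_chunk_size(max_len=512)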
+
+Pre-configured options
+~~~~~~~~~~~~~~~~~~~~~~
+
+There are some training options, e.g., number of encoder layers,
+encoder dimension, decoder dimension, number of warmup steps etc,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`pruned_transducer_stateless4/train.py `_
+
+You don't need to change these pre-configured parameters. If you really need to change
+them, please modify ``./pruned_transducer_stateless4/train.py`` directly.
+
+
+.. NOTE::
+
+ The options for `pruned_transducer_stateless5 `_ are a little different from
+ those of other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from the commandline, so that you can train models of different sizes with pruned_transducer_stateless5.
+
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in ``--exp-dir`` (e.g., ``pruned_transducer_stateless4/exp``).
+You will find the following files in that directory:
+
+ - ``epoch-1.pt``, ``epoch-2.pt``, ...
+
+ These are checkpoint files saved at the end of each epoch, containing model
+ ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+ .. code-block:: bash
+
+ $ ./pruned_transducer_stateless4/train.py --start-epoch 11
+
+ - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
+
+ These are checkpoint files saved every ``--save-every-n`` batches,
+ containing model ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
+
+ .. code-block:: bash
+
+ $ ./pruned_transducer_stateless4/train.py --start-batch 436000
+
+ - ``tensorboard/``
+
+ This folder contains TensorBoard logs. Training loss, validation loss, learning
+ rate, etc., are recorded in these logs. You can visualize them by:
+
+ .. code-block:: bash
+
+ $ cd pruned_transducer_stateless4/exp/tensorboard
+ $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall"
+
+ It will print something like below:
+
+ .. code-block::
+
+ TensorFlow installation not found - running with reduced feature set.
+ Upload started and will continue reading any new data as it's added to the logdir.
+
+ To stop uploading, press Ctrl-C.
+
+ New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/97VKXf80Ru61CnP2ALWZZg/
+
+ [2022-11-20T15:50:50] Started scanning logdir.
+ Uploading 4468 scalars...
+ [2022-11-20T15:53:02] Total uploaded: 210171 scalars, 0 tensors, 0 binary objects
+ Listening for new data in logdir...
+
+ Note there is a URL in the above output. Click it and you will see
+ the following screenshot:
+
+ .. figure:: images/streaming-librispeech-pruned-transducer-tensorboard-log.jpg
+ :width: 600
+ :alt: TensorBoard screenshot
+ :align: center
+ :target: https://tensorboard.dev/experiment/97VKXf80Ru61CnP2ALWZZg/
+
+ TensorBoard screenshot.
+
+ .. hint::
+
+ If you don't have access to google, you can use the following command
+ to view the tensorboard log locally:
+
+ .. code-block:: bash
+
+ cd pruned_transducer_stateless4/exp/tensorboard
+ tensorboard --logdir . --port 6008
+
+ It will print the following message:
+
+ .. code-block::
+
+ Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
+ TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
+
+ Now start your browser and go to ``_ to view the tensorboard
+ logs.
+
+
+ - ``log/log-train-xxxx``
+
+ It is the detailed training log in text format, same as the one
+ you saw printed to the console during training.
+
+Usage example
+~~~~~~~~~~~~~
+
+You can use the following command to start the training using 4 GPUs:
+
+.. code-block:: bash
+
+ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ ./pruned_transducer_stateless4/train.py \
+ --world-size 4 \
+ --dynamic-chunk-training 1 \
+ --causal-convolution 1 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --full-libri 1 \
+ --max-duration 300
+
+.. NOTE::
+
+ Comparing with training a non-streaming model, you only need to add two extra options,
+ ``--dynamic-chunk-training 1`` and ``--causal-convolution 1`` .
+
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. hint::
+
+ There are two kinds of checkpoints:
+
+ - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
+ of each epoch. You can pass ``--epoch`` to
+ ``pruned_transducer_stateless4/decode.py`` to use them.
+
+ - (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved
+ every ``--save-every-n`` batches. You can pass ``--iter`` to
+ ``pruned_transducer_stateless4/decode.py`` to use them.
+
+ We suggest that you try both types of checkpoints and choose the one
+ that produces the lowest WERs.
+
+.. tip::
+
+ To decode a streaming model, you can use either ``simulate streaming decoding`` in ``decode.py`` or
+ ``real streaming decoding`` in ``streaming_decode.py``. The difference is that
+ ``decode.py`` processes all acoustic frames at once with masking (i.e., the same as training),
+ while ``streaming_decode.py`` processes the acoustic frames chunk by chunk (so it can only see limited context).
+
+.. NOTE::
+
+ ``simulate streaming decoding`` in ``decode.py`` and ``real streaming decoding`` in ``streaming_decode.py`` should
+ produce almost the same results given the same ``--decode-chunk-size`` and ``--left-context``.
+
+
+Simulate streaming decoding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless4/decode.py --help
+
+shows the options for decoding.
+The following options are important for streaming models:
+
+ ``--simulate-streaming``
+
+ If you want to decode a streaming model with ``decode.py``, you **MUST** set
+ ``--simulate-streaming`` to ``True``. ``simulate`` here means the acoustic frames
+ are not processed frame by frame (or chunk by chunk); instead, the whole sequence
+ is processed at once with masking (the same as in training).
+
+ ``--causal-convolution``
+
+ If True, the convolution module in the encoder layers will be causal convolution.
+ This **MUST** be True when decoding a streaming model.
+
+ ``--decode-chunk-size``
+
+ For streaming models, we calculate chunk-wise attention; ``--decode-chunk-size``
+ indicates the chunk length (in frames after subsampling) for chunk-wise attention.
+ For ``simulate streaming decoding``, ``decode-chunk-size`` is used to generate
+ the attention mask.
+
+ ``--left-context``
+
+ ``--left-context`` indicates how many left-context frames (after subsampling) can be seen
+ by the current chunk when calculating chunk-wise attention. Normally, ``left-context`` should equal
+ ``decode-chunk-size * num-left-chunks``, where ``num-left-chunks`` is the option used
+ to train this model (see the worked example below). For ``simulate streaming decoding``, ``left-context`` is used to generate
+ the attention mask.
+
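+As a quick worked example of the relationship above (matching the commands below, which pass
+``--decode-chunk-size 16`` and ``--left-context 64``):
+
+.. code-block:: python
+
+ decode_chunk_size = 16                     # frames after subsampling, per chunk
+ num_left_chunks = 4                        # the value this model was trained with
+ left_context = decode_chunk_size * num_left_chunks
+ assert left_context == 64                  # hence --left-context 64 below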
+
+The following shows two examples (for the two types of checkpoints):
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ for epoch in 25 20; do
+ for avg in 7 5 3 1; do
+ ./pruned_transducer_stateless4/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --simulate-streaming 1 \
+ --causal-convolution 1 \
+ --decode-chunk-size 16 \
+ --left-context 64 \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --max-duration 600 \
+ --decoding-method $m
+ done
+ done
+ done
+
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ for iter in 474000; do
+ for avg in 8 10 12 14 16 18; do
+ ./pruned_transducer_stateless4/decode.py \
+ --iter $iter \
+ --avg $avg \
+ --simulate-streaming 1 \
+ --causal-convolution 1 \
+ --decode-chunk-size 16 \
+ --left-context 64 \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --max-duration 600 \
+ --decoding-method $m
+ done
+ done
+ done
+
+
+Real streaming decoding
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless4/streaming_decode.py --help
+
+shows the options for decoding.
+The following options are important for streaming models:
+
+ ``--decode-chunk-size``
+
+ For streaming models, we calculate chunk-wise attention; ``--decode-chunk-size``
+ indicates the chunk length (in frames after subsampling) for chunk-wise attention.
+ For ``real streaming decoding``, we process ``decode-chunk-size`` acoustic frames at a time.
+
+ ``--left-context``
+
+ ``--left-context`` indicates how many left-context frames (after subsampling) can be seen
+ by the current chunk when calculating chunk-wise attention. Normally, ``left-context`` should equal
+ ``decode-chunk-size * num-left-chunks``, where ``num-left-chunks`` is the option used
+ to train this model.
+
+ ``--num-decode-streams``
+
+ The number of decoding streams that can be run in parallel (very similar to the ``batch size``).
+ For ``real streaming decoding``, the batches are packed dynamically. For example, if
+ ``num-decode-streams`` equals 10, sequences 1 to 10 are decoded first; once, say,
+ sequences 1 and 2 are finished, sequences 3 to 12 are processed in parallel in a batch (see the sketch below).
+
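+A toy sketch of this dynamic packing is shown below (``Stream`` and ``decode_one_chunk`` are
+stand-ins invented for this sketch; the real script keeps encoder states and hypotheses per stream):
+
+.. code-block:: python
+
+ from dataclasses import dataclass
+ from typing import List
+
+ @dataclass
+ class Stream:
+     chunks_left: int                       # chunks still to be processed for this utterance
+
+     @property
+     def done(self) -> bool:
+         return self.chunks_left <= 0
+
+ def decode_one_chunk(streams: List[Stream]) -> None:
+     # the real script runs the encoder/decoder on one chunk per active stream here
+     for s in streams:
+         s.chunks_left -= 1
+
+ num_decode_streams = 3
+ pending = [Stream(chunks_left=n) for n in (2, 5, 1, 4, 3)]   # 5 utterances to decode
+ active: List[Stream] = []
+
+ while pending or active:
+     while pending and len(active) < num_decode_streams:
+         active.append(pending.pop(0))                 # top up the batch
+     decode_one_chunk(active)                          # one chunk for every active stream
+     active = [s for s in active if not s.done]        # finished streams free their slots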
+
+.. NOTE::
+
+ We also tried adding ``--right-context`` in real streaming decoding, but it does not seem to benefit
+ performance for all models; the reason might be the mismatch between training and decoding. You
+ can try decoding with ``--right-context`` to see if it helps. The default value is 0.
+
+
+The following shows two examples (for the two types of checkpoints):
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ for epoch in 25 20; do
+ for avg in 7 5 3 1; do
+ ./pruned_transducer_stateless4/streaming_decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --decode-chunk-size 16 \
+ --left-context 64 \
+ --num-decode-streams 100 \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --max-duration 600 \
+ --decoding-method $m
+ done
+ done
+ done
+
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ for iter in 474000; do
+ for avg in 8 10 12 14 16 18; do
+ ./pruned_transducer_stateless4/streaming_decode.py \
+ --iter $iter \
+ --avg $avg \
+ --decode-chunk-size 16 \
+ --left-context 64 \
+ --num-decode-streams 100 \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --max-duration 600 \
+ --decoding-method $m
+ done
+ done
+ done
+
+
+.. tip::
+
+ The supported decoding methods are as follows:
+
+ - ``greedy_search`` : It takes the symbol with the largest posterior probability
+ at each frame as the decoding result (a minimal sketch appears after this list).
+
+ - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
+ `espnet/nets/beam_search_transducer.py `_
+ is used as a reference. Basically, it keeps the top-k states for each frame and expands the kept states with their own contexts to
+ the next frame.
+
+ - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
+ runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded.
+
+ - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and
+ given ``FSAs``. It is hard to describe the details in a few lines of text; you can read
+ our paper in https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently.
+
+ - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, except that ``fast_beam_search`` uses
+ a trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph
+ (with an N-gram LM).
+
+ - ``fast_beam_search_nbest`` : It produces the decoding results as follows:
+
+ - (1) Use ``fast_beam_search`` to get a lattice
+ - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()``
+ - (3) Unique the selected paths
+ - (4) Intersect the selected paths with the lattice and compute the
+ shortest path from the intersection result
+ - (5) The path with the largest score is used as the decoding output.
+
+ - ``fast_beam_search_nbest_LG`` : It implements the same logic as ``fast_beam_search_nbest``; the
+ only difference is that it uses ``fast_beam_search_LG`` to generate the lattice.
+
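+For intuition, here is a very reduced sketch of ``greedy_search`` (the real transducer greedy
+search re-runs the decoder and joiner whenever a non-blank symbol is emitted, which this sketch
+deliberately omits; the tensor shapes are made up for the example):
+
+.. code-block:: python
+
+ import torch
+
+ def greedy_search(log_probs: torch.Tensor, blank_id: int = 0) -> list:
+     """log_probs: (T, vocab_size) scores for one utterance; pick the best
+     symbol per frame and drop blanks."""
+     hyp = []
+     for t in range(log_probs.size(0)):
+         sym = int(log_probs[t].argmax())
+         if sym != blank_id:
+             hyp.append(sym)
+     return hyp
+
+ # toy example: 5 frames, vocabulary of 4 symbols where 0 is the blank
+ hyp = greedy_search(torch.randn(5, 4).log_softmax(dim=-1))
+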
+.. NOTE::
+
+ The decoding methods supported in ``streaming_decode.py`` might be fewer than those in ``decode.py``; if needed,
+ you can implement them yourself or file an issue in `icefall `_ .
+
+
+Export Model
+------------
+
+`pruned_transducer_stateless4/export.py `_ supports exporting checkpoints from ``pruned_transducer_stateless4/exp`` in the following ways.
+
+Export ``model.state_dict()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Checkpoints saved by ``pruned_transducer_stateless4/train.py`` also include
+``optimizer.state_dict()``. It is useful for resuming training. But after training,
+we are interested only in ``model.state_dict()``. You can use the following
+command to extract ``model.state_dict()``.
+
+.. code-block:: bash
+
+ # Assume that --epoch 25 --avg 3 produces the smallest WER
+ # (You can get such information after running ./pruned_transducer_stateless4/decode.py)
+
+ epoch=25
+ avg=3
+
+ ./pruned_transducer_stateless4/export.py \
+ --exp-dir ./pruned_transducer_stateless4/exp \
+ --streaming-model 1 \
+ --causal-convolution 1 \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch $epoch \
+ --avg $avg
+
+.. caution::
+
+ ``--streaming-model`` and ``--causal-convolution`` must be True to export
+ a streaming model.
+
+It will generate a file ``./pruned_transducer_stateless4/exp/pretrained.pt``.
+
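+If you want to inspect the exported file from Python, a minimal sketch is shown below
+(it assumes the checkpoint is a dict with a ``"model"`` entry holding the ``state_dict``,
+which is how the export script above is expected to save it; adjust the key if your
+checkpoint is organized differently):
+
+.. code-block:: python
+
+ import torch
+
+ checkpoint = torch.load("pruned_transducer_stateless4/exp/pretrained.pt", map_location="cpu")
+ state_dict = checkpoint["model"]            # the model parameters only
+ print(len(state_dict), "tensors in model.state_dict()")
+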
+.. hint::
+
+ To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless4/decode.py``,
+ you can run:
+
+ .. code-block:: bash
+
+ cd pruned_transducer_stateless4/exp
+ ln -s pretrained.pt epoch-999.pt
+
+ And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to
+ ``./pruned_transducer_stateless4/decode.py``.
+
+To use the exported model with ``./pruned_transducer_stateless4/pretrained.py``, you
+can run:
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless4/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless4/exp/pretrained.pt \
+ --simulate-streaming 1 \
+ --causal-convolution 1 \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+
+Export model using ``torch.jit.script()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless4/export.py \
+ --exp-dir ./pruned_transducer_stateless4/exp \
+ --streaming-model 1 \
+ --causal-convolution 1 \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch 25 \
+ --avg 3 \
+ --jit 1
+
+.. caution::
+
+ ``--streaming-model`` and ``--causal-convolution`` must be True to export
+ a streaming model.
+
+It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
+load it by ``torch.jit.load("cpu_jit.pt")``.
+
+Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python
+are on CPU. You can use ``to("cuda")`` to move them to a CUDA device.
+
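+A minimal usage sketch (plain PyTorch, no icefall-specific code):
+
+.. code-block:: python
+
+ import torch
+
+ model = torch.jit.load("cpu_jit.pt")   # parameters are loaded onto the CPU
+ model.eval()
+ if torch.cuda.is_available():
+     model = model.to("cuda")           # move the parameters to a CUDA device
+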
+.. NOTE::
+
+ You will need this ``cpu_jit.pt`` when deploying with the Sherpa framework.
+
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following links:
+
+ - `pruned_transducer_stateless `_
+
+ - `pruned_transducer_stateless2 `_
+
+ - `pruned_transducer_stateless4 `_
+
+ - `pruned_transducer_stateless5 `_
+
+ See ``_
+ for the details of the above pretrained models.
+
+
+Deploy with Sherpa
+------------------
+
+Please see ``_
+for how to deploy the models in ``sherpa``.
diff --git a/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst
new file mode 100644
index 000000000..f0e8961d7
--- /dev/null
+++ b/docs/source/recipes/Streaming-ASR/librispeech/zipformer_transducer.rst
@@ -0,0 +1,654 @@
+Zipformer Transducer
+====================
+
+This tutorial shows you how to run a **streaming** zipformer transducer model
+with the `LibriSpeech `_ dataset.
+
+.. Note::
+
+ The tutorial is suitable for `pruned_transducer_stateless7_streaming `_.
+
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have setup
+ the environment for ``icefall``.
+
+.. HINT::
+
+ We recommend that you use a GPU or several GPUs to run this recipe.
+
+.. hint::
+
+ Please scroll down to the bottom of this page to find download links
+ for pretrained models if you don't want to train a model from scratch.
+
+
+We use pruned RNN-T to compute the loss.
+
+.. note::
+
+ You can find the paper about pruned RNN-T at the following address:
+
+ ``_
+
+The transducer model consists of 3 parts:
+
+ - Encoder, a.k.a, the transcription network. We use a Zipformer model (proposed by Daniel Povey)
+ - Decoder, a.k.a, the prediction network. We use a stateless model consisting of
+ ``nn.Embedding`` and ``nn.Conv1d``
+ - Joiner, a.k.a, the joint network.
+
+.. caution::
+
+ Contrary to the conventional RNN-T models, we use a stateless decoder.
+ That is, it has no recurrent connections.
+
+
+Data preparation
+----------------
+
+.. hint::
+
+ The data preparation is the same as for other LibriSpeech recipes.
+ If you have already finished this step, you can skip to ``Training`` directly.
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+The data preparation contains several stages; you can use the following two
+options:
+
+ - ``--stage``
+ - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 2 --stop-stage 5
+
+.. HINT::
+
+ If you have pre-downloaded the `LibriSpeech `_
+ dataset and the `musan `_ dataset, say,
+ they are saved in ``/tmp/LibriSpeech`` and ``/tmp/musan``, you can modify
+ the ``dl_dir`` variable in ``./prepare.sh`` to point to ``/tmp`` so that
+ ``./prepare.sh`` won't re-download them.
+
+.. NOTE::
+
+ All generated files by ``./prepare.sh``, e.g., features, lexicon, etc,
+ are saved in ``./data`` directory.
+
+We provide the following YouTube video showing how to run ``./prepare.sh``.
+
+.. note::
+
+ To get the latest news of `next-gen Kaldi `_, please subscribe to
+ the following YouTube channel by `Nadira Povey `_:
+
+ ``_
+
+.. youtube:: ofEIoJL-mGM
+
+
+Training
+--------
+
+Configurable options
+~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless7_streaming/train.py --help
+
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+ - ``--exp-dir``
+
+ The directory to save checkpoints, training logs and tensorboard.
+
+ - ``--full-libri``
+
+ If it's True, the training part uses all the training data, i.e.,
+ 960 hours. Otherwise, the training part uses only the subset
+ ``train-clean-100``, which has 100 hours of training data.
+
+ .. CAUTION::
+ The training set is perturbed by speed with two factors: 0.9 and 1.1.
+ If ``--full-libri`` is True, each epoch actually processes
+ ``3x960 == 2880`` hours of data.
+
+ - ``--num-epochs``
+
+ It is the number of epochs to train. For instance,
+ ``./pruned_transducer_stateless7_streaming/train.py --num-epochs 30`` trains for 30 epochs
+ and generates ``epoch-1.pt``, ``epoch-2.pt``, ..., ``epoch-30.pt``
+ in the folder ``./pruned_transducer_stateless7_streaming/exp``.
+
+ - ``--start-epoch``
+
+ It's used to resume training.
+ ``./pruned_transducer_stateless7_streaming/train.py --start-epoch 10`` loads the
+ checkpoint ``./pruned_transducer_stateless7_streaming/exp/epoch-9.pt`` and starts
+ training from epoch 10, based on the state from epoch 9.
+
+ - ``--world-size``
+
+ It is used for multi-GPU single-machine DDP training.
+
+ - (a) If it is 1, then no DDP training is used.
+
+ - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+ The following shows some use cases with it.
+
+ **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+ GPU 2 for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="0,2"
+ $ ./pruned_transducer_stateless7_streaming/train.py --world-size 2
+
+ **Use case 2**: You have 4 GPUs and you want to use all of them
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless7_streaming/train.py --world-size 4
+
+ **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ export CUDA_VISIBLE_DEVICES="3"
+ $ ./pruned_transducer_stateless7_streaming/train.py --world-size 1
+
+ .. caution::
+
+ Only multi-GPU single-machine DDP training is implemented at present.
+ Multi-GPU multi-machine DDP training will be added later.
+
+ - ``--max-duration``
+
+ It specifies the number of seconds over all utterances in a
+ batch, before **padding**.
+ If you encounter CUDA OOM, please reduce it.
+
+ .. HINT::
+
+ Due to padding, the number of seconds of all utterances in a
+ batch will usually be larger than ``--max-duration``.
+
+ A larger value for ``--max-duration`` may cause OOM during training,
+ while a smaller value may increase the training time. You have to
+ tune it.
+
+ - ``--use-fp16``
+
+ If it is True, the model will be trained with half precision. From our experiments,
+ using half precision allows a ``--max-duration`` about twice as large,
+ which gives an almost 2x speedup.
+
+ We recommend using ``--use-fp16 True``.
+
+ - ``--short-chunk-size``
+
+ When training a streaming attention model with chunk masking, the chunk size is either
+ the maximum sequence length of the current batch or uniformly sampled from
+ (1, short_chunk_size). The default value is 50; you don't have to change it most of the time.
+
+ - ``--num-left-chunks``
+
+ It indicates how much left context (in chunks) can be seen when calculating attention.
+ The default value is 4; you don't have to change it most of the time.
+
+
+ - ``--decode-chunk-len``
+
+ The chunk size for decoding (in frames before subsampling). It is used for validation.
+ The default value is 32 (i.e., 320ms); see the short example after this list.
+
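+As a quick sanity check on the units (assuming the default 10 ms fbank frame shift, which is
+what makes 32 frames correspond to the 320 ms mentioned above):
+
+.. code-block:: python
+
+ frame_shift_ms = 10                    # fbank frame shift assumed here
+ decode_chunk_len = 32                  # frames before subsampling
+ chunk_ms = decode_chunk_len * frame_shift_ms
+ assert chunk_ms == 320                 # i.e., a 320 ms chunk per decoding step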
+
+Pre-configured options
+~~~~~~~~~~~~~~~~~~~~~~
+
+There are some training options, e.g., number of encoder layers,
+encoder dimension, decoder dimension, number of warmup steps etc,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`pruned_transducer_stateless7_streaming/train.py `_
+
+You don't need to change these pre-configured parameters. If you really need to change
+them, please modify ``./pruned_transducer_stateless7_streaming/train.py`` directly.
+
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in ``--exp-dir`` (e.g., ``pruned_transducer_stateless7_streaming/exp``).
+You will find the following files in that directory:
+
+ - ``epoch-1.pt``, ``epoch-2.pt``, ...
+
+ These are checkpoint files saved at the end of each epoch, containing model
+ ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+ .. code-block:: bash
+
+ $ ./pruned_transducer_stateless7_streaming/train.py --start-epoch 11
+
+ - ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ...
+
+ These are checkpoint files saved every ``--save-every-n`` batches,
+ containing model ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``checkpoint-436000``, you can use:
+
+ .. code-block:: bash
+
+ $ ./pruned_transducer_stateless7_streaming/train.py --start-batch 436000
+
+ - ``tensorboard/``
+
+ This folder contains TensorBoard logs. Training loss, validation loss, learning
+ rate, etc., are recorded in these logs. You can visualize them by:
+
+ .. code-block:: bash
+
+ $ cd pruned_transducer_stateless7_streaming/exp/tensorboard
+ $ tensorboard dev upload --logdir . --description "pruned transducer training for LibriSpeech with icefall"
+
+ .. hint::
+
+ If you don't have access to google, you can use the following command
+ to view the tensorboard log locally:
+
+ .. code-block:: bash
+
+ cd pruned_transducer_stateless7_streaming/exp/tensorboard
+ tensorboard --logdir . --port 6008
+
+ It will print the following message:
+
+ .. code-block::
+
+ Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
+ TensorBoard 2.8.0 at http://localhost:6008/ (Press CTRL+C to quit)
+
+ Now start your browser and go to ``_ to view the tensorboard
+ logs.
+
+
+ - ``log/log-train-xxxx``
+
+ It is the detailed training log in text format, same as the one
+ you saw printed to the console during training.
+
+Usage example
+~~~~~~~~~~~~~
+
+You can use the following command to start the training using 4 GPUs:
+
+.. code-block:: bash
+
+ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ ./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --full-libri 1 \
+ --max-duration 550
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. hint::
+
+ There are two kinds of checkpoints:
+
+ - (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
+ of each epoch. You can pass ``--epoch`` to
+ ``pruned_transducer_stateless7_streaming/decode.py`` to use them.
+
+ - (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved
+ every ``--save-every-n`` batches. You can pass ``--iter`` to
+ ``pruned_transducer_stateless7_streaming/decode.py`` to use them.
+
+ We suggest that you try both types of checkpoints and choose the one
+ that produces the lowest WERs.
+
+.. tip::
+
+ To decode a streaming model, you can use either ``simulate streaming decoding`` in ``decode.py`` or
+ ``real chunk-wise streaming decoding`` in ``streaming_decode.py``. The difference is that
+ ``decode.py`` processes all acoustic frames at once with masking (i.e., the same as training),
+ while ``streaming_decode.py`` processes the acoustic frames chunk by chunk.
+
+.. NOTE::
+
+ ``simulate streaming decoding`` in ``decode.py`` and ``real chunk-wise streaming decoding`` in ``streaming_decode.py`` should
+ produce almost the same results given the same ``--decode-chunk-len``.
+
+
+Simulate streaming decoding
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless7_streaming/decode.py --help
+
+shows the options for decoding.
+The following options are important for streaming models:
+
+ ``--decode-chunk-len``
+
+ It is the same as in ``train.py``; it specifies the chunk size for decoding (in frames before subsampling).
+ The default value is 32 (i.e., 320ms).
+
+
+The following shows two examples (for the two types of checkpoints):
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ for epoch in 30; do
+ for avg in 12 11 10 9 8; do
+ ./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --decode-chunk-len 32 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decoding-method $m
+ done
+ done
+ done
+
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ for iter in 474000; do
+ for avg in 8 10 12 14 16 18; do
+ ./pruned_transducer_stateless7_streaming/decode.py \
+ --iter $iter \
+ --avg $avg \
+ --decode-chunk-len 32 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decoding-method $m
+ done
+ done
+ done
+
+
+Real streaming decoding
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ $ cd egs/librispeech/ASR
+ $ ./pruned_transducer_stateless7_streaming/streaming_decode.py --help
+
+shows the options for decoding.
+The following options are important for streaming models:
+
+ ``--decode-chunk-len``
+
+ It is the same as in ``train.py``; it specifies the chunk size for decoding (in frames before subsampling).
+ The default value is 32 (i.e., 320ms).
+ For ``real streaming decoding``, we process ``decode-chunk-len`` acoustic frames at a time.
+
+ ``--num-decode-streams``
+
+ The number of decoding streams that can be run in parallel (very similar to the ``batch size``).
+ For ``real streaming decoding``, the batches are packed dynamically. For example, if
+ ``num-decode-streams`` equals 10, sequences 1 to 10 are decoded first; once, say,
+ sequences 1 and 2 are finished, sequences 3 to 12 are processed in parallel in a batch.
+
+
+The following shows two examples (for the two types of checkpoints):
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ for epoch in 30; do
+ for avg in 12 11 10 9 8; do
+ ./pruned_transducer_stateless7_streaming/streaming_decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --decode-chunk-len 32 \
+ --num-decode-streams 100 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --decoding-method $m
+ done
+ done
+ done
+
+
+.. code-block:: bash
+
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ for iter in 474000; do
+ for avg in 8 10 12 14 16 18; do
+ ./pruned_transducer_stateless7_streaming/streaming_decode.py \
+ --iter $iter \
+ --avg $avg \
+ --decode-chunk-len 16 \
+ --num-decode-streams 100 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --decoding-method $m
+ done
+ done
+ done
+
+
+.. tip::
+
+ The supported decoding methods are as follows:
+
+ - ``greedy_search`` : It takes the symbol with the largest posterior probability
+ at each frame as the decoding result.
+
+ - ``beam_search`` : It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf and
+ `espnet/nets/beam_search_transducer.py `_
+ is used as a reference. Basically, it keeps the top-k states for each frame and expands the kept states with their own contexts to
+ the next frame.
+
+ - ``modified_beam_search`` : It implements the same algorithm as ``beam_search`` above, but it
+ runs in batch mode with ``--max-sym-per-frame=1`` being hardcoded.
+
+ - ``fast_beam_search`` : It implements graph composition between the output ``log_probs`` and
+ given ``FSAs``. It is hard to describe the details in a few lines of text; you can read
+ our paper in https://arxiv.org/pdf/2211.00484.pdf or our `rnnt decode code in k2 `_. ``fast_beam_search`` can decode with ``FSAs`` on GPU efficiently.
+
+ - ``fast_beam_search_LG`` : The same as ``fast_beam_search`` above, except that ``fast_beam_search`` uses
+ a trivial graph that has only one state, while ``fast_beam_search_LG`` uses an LG graph
+ (with an N-gram LM).
+
+ - ``fast_beam_search_nbest`` : It produces the decoding results as follows:
+
+ - (1) Use ``fast_beam_search`` to get a lattice
+ - (2) Select ``num_paths`` paths from the lattice using ``k2.random_paths()``
+ - (3) Unique the selected paths
+ - (4) Intersect the selected paths with the lattice and compute the
+ shortest path from the intersection result
+ - (5) The path with the largest score is used as the decoding output.
+
+ - ``fast_beam_search_nbest_LG`` : It implements the same logic as ``fast_beam_search_nbest``; the
+ only difference is that it uses ``fast_beam_search_LG`` to generate the lattice.
+
+.. NOTE::
+
+ The decoding methods supported in ``streaming_decode.py`` might be fewer than those in ``decode.py``; if needed,
+ you can implement them yourself or file an issue in `icefall `_ .
+
+
+Export Model
+------------
+
+Currently it supports exporting checkpoints from ``pruned_transducer_stateless7_streaming/exp`` in the following ways.
+
+Export ``model.state_dict()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Checkpoints saved by ``pruned_transducer_stateless7_streaming/train.py`` also include
+``optimizer.state_dict()``. It is useful for resuming training. But after training,
+we are interested only in ``model.state_dict()``. You can use the following
+command to extract ``model.state_dict()``.
+
+.. code-block:: bash
+
+ # Assume that --epoch 30 --avg 9 produces the smallest WER
+ # (You can get such information after running ./pruned_transducer_stateless7_streaming/decode.py)
+
+ epoch=30
+ avg=9
+
+ ./pruned_transducer_stateless7_streaming/export.py \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch $epoch \
+ --avg $avg \
+ --use-averaged-model=True \
+ --decode-chunk-len 32
+
+It will generate a file ``./pruned_transducer_stateless7_streaming/exp/pretrained.pt``.
+
+.. hint::
+
+ To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_streaming/decode.py``,
+ you can run:
+
+ .. code-block:: bash
+
+ cd pruned_transducer_stateless7_streaming/exp
+ ln -s pretrained.pt epoch-999.pt
+
+ And then pass ``--epoch 999 --avg 1 --use-averaged-model 0`` to
+ ``./pruned_transducer_stateless7_streaming/decode.py``.
+
+To use the exported model with ``./pruned_transducer_stateless7_streaming/pretrained.py``, you
+can run:
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless7_streaming/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --method greedy_search \
+ --decode-chunk-len 32 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+
+Export model using ``torch.jit.script()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless7_streaming/export.py \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 9 \
+ --decode-chunk-len 32 \
+ --jit 1
+
+.. caution::
+
+ ``--decode-chunk-len`` is required to export a ScriptModule.
+
+It will generate a file ``cpu_jit.pt`` in the given ``exp_dir``. You can later
+load it by ``torch.jit.load("cpu_jit.pt")``.
+
+Note ``cpu`` in the name ``cpu_jit.pt`` means the parameters when loaded into Python
+are on CPU. You can use ``to("cuda")`` to move them to a CUDA device.
+
+Export model using ``torch.jit.trace()``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+ epoch=30
+ avg=9
+
+ ./pruned_transducer_stateless7_streaming/jit_trace_export.py \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --use-averaged-model=True \
+ --decode-chunk-len 32 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --epoch $epoch \
+ --avg $avg
+
+.. caution::
+
+ ``--decode-chunk-len`` is required to export a ScriptModule.
+
+It will generate 3 files:
+
+ - ``./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt``
+ - ``./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt``
+ - ``./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt``
+
+To use the generated files with ``./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py``:
+
+.. code-block:: bash
+
+ ./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \
+ --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \
+ --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \
+ --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --decode-chunk-len 32 \
+ /path/to/foo.wav
+
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following links:
+
+ - `pruned_transducer_stateless7_streaming `_
+
+ See ``_
+ for the details of the above pretrained models.
+
+Deploy with Sherpa
+------------------
+
+Please see ``_
+for how to deploy the models in ``sherpa``.
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index 9d1d83d29..63793275c 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -13,7 +13,5 @@ We may add recipes for other tasks as well in the future.
:maxdepth: 2
:caption: Table of Contents
- aishell/index
- librispeech/index
- timit/index
- yesno/index
+ Non-streaming-ASR/index
+ Streaming-ASR/index
diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
index fb2751c0f..387c14acf 100755
--- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
- cut_set
- + cut_set.perturb_speed(0.9)
- + cut_set.perturb_speed(1.1)
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@@ -116,9 +114,7 @@ def get_args():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/aidatatang_200zh/ASR/local/prepare_char.py
+++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
- pieces = [
- token2id[i] if i in token2id else token2id[""] for i in pieces
- ]
+ pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
-def generate_lexicon(
- token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py
+++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
- parser.add_argument(
- "--lang-dir", type=str, help="The lang dir, data/lang_phone"
- )
+ parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
return parser.parse_args()
diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
+++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
- fsa_disambig = lexicon_to_fst(
- lexicon_disambig, phone2id=phone2id, word2id=word2id
- )
+ fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/aidatatang_200zh/ASR/local/text2token.py
+++ b/egs/aidatatang_200zh/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
- parser.add_argument(
- "--space", default="", type=str, help="space symbol"
- )
+ parser.add_argument("--space", default="", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
@@ -66,9 +64,7 @@ def get_parser():
type=str,
help="list of non-linguistic symobles, e.g., etc.",
)
- parser.add_argument(
- "text", type=str, default=False, nargs="?", help="input text"
- )
+ parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
@@ -108,8 +104,7 @@ def token2id(
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
- token_table[txt] if txt in token_table else oov_id
- for txt in text
+ token_table[txt] if txt in token_table else oov_id for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
- f = codecs.getreader("utf-8")(
- sys.stdin if is_python2 else sys.stdin.buffer
- )
+ f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh
index 039951354..46ecd5769 100755
--- a/egs/aidatatang_200zh/ASR/prepare.sh
+++ b/egs/aidatatang_200zh/ASR/prepare.sh
@@ -1,5 +1,8 @@
#!/usr/bin/env bash
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
set -eou pipefail
stage=-1
@@ -106,11 +109,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
if [ ! -f $lang_char_dir/words.txt ]; then
./local/prepare_words.py \
--input-file $lang_char_dir/words_no_ids.txt \
- --output-file $lang_char_dir/words.txt
+ --output-file $lang_char_dir/words.txt
fi
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
./local/prepare_char.py
fi
fi
-
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 6a5b57e24..167d5e15e 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -205,17 +205,13 @@ class Aidatatang_200zhAsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
- cuts_musan = load_manifest(
- self.args.manifest_dir / "musan_cuts.jsonl.gz"
- )
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
- CutMix(
- cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
- )
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@@ -237,9 +233,7 @@ class Aidatatang_200zhAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
- logging.info(
- f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
- )
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@@ -282,9 +276,7 @@ class Aidatatang_200zhAsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@@ -340,9 +332,7 @@ class Aidatatang_200zhAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
index f0407f429..2512f233f 100755
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
@@ -69,11 +69,7 @@ from beam_search import (
)
from train import get_params, get_transducer_model
-from icefall.checkpoint import (
- average_checkpoints,
- find_checkpoints,
- load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@@ -192,8 +188,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
- encoder_out, encoder_out_lens = model.encoder(
- x=feature, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
@@ -266,10 +259,7 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -390,9 +380,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -403,18 +391,14 @@ def save_results(
):
test_set_wers = dict()
for key, results in results_dict.items():
- recog_path = (
- params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
- errs_filename = (
- params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
@@ -424,10 +408,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
- errs_info = (
- params.res_dir
- / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
- )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
index 00b54c39f..e348f7b2b 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -103,8 +103,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
return parser
@@ -173,9 +172,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
index eb5e6b0d4..75c316eaf 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
@@ -162,8 +162,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -193,10 +192,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -257,9 +255,7 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
@@ -284,10 +280,7 @@ def main():
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -339,9 +332,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
index d46838b68..c9d9c4aa8 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
@@ -81,9 +81,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
-LRSchedulerType = Union[
- torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@@ -187,8 +185,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -211,8 +208,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
- help="The scale to smooth the loss with am (output of encoder network)"
- "part.",
+ help="The scale to smooth the loss with am (output of encoder network) part.",
)
parser.add_argument(
@@ -542,22 +538,15 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
- 0.0
- if warmup < 1.0
- else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
- )
- loss = (
- params.simple_loss_scale * simple_loss
- + pruned_loss_scale * pruned_loss
+ 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
+ loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- info["frames"] = (
- (feature_lens // params.subsampling_factor).sum().item()
- )
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -711,9 +700,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@@ -813,7 +800,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2 ** 22
+ 2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md
index 75fc6326e..f4a59e552 100644
--- a/egs/aishell/ASR/README.md
+++ b/egs/aishell/ASR/README.md
@@ -1,7 +1,7 @@
# Introduction
-Please refer to
+Please refer to
for how to run models in this recipe.
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 9a06fbe9f..aa18502c2 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -2,6 +2,57 @@
### Aishell training result(Stateless Transducer)
+#### Pruned transducer stateless 7 (zipformer)
+
+See
+
+[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe)
+
+**Note**: The modeling units are byte level BPEs
+
+The best results I have gotten are:
+
+Vocab size | Greedy search (dev & test) | Modified beam search (dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | comments
+-- | -- | -- | -- | -- | --
+500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29
+
+The training command:
+
+```
+export CUDA_VISIBLE_DEVICES="4,5,6,7"
+
+./pruned_transducer_stateless7_bbpe/train.py \
+ --world-size 4 \
+ --num-epochs 50 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --max-duration 800 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --exp-dir pruned_transducer_stateless7_bbpe/exp \
+ --lr-epochs 6 \
+ --master-port 12535
+```
+
+The decoding command:
+
+```
+for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do
+ ./pruned_transducer_stateless7_bbpe/decode.py \
+ --epoch 48 \
+ --avg 29 \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --max-sym-per-frame 1 \
+ --ngram-lm-scale 0.25 \
+ --ilme-scale 0.2 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 2000 \
+ --decoding-method $m
+done
+```
+
+The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
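+
+To give a rough idea of what "byte level BPE" means here, the sketch below
+(illustrative only, not part of the recipe) byte-encodes a sentence before
+passing it to the sentencepiece model, using the same helpers that
+`local/prepare_lang_bbpe.py` relies on. The model path comes from the training
+command above; the sample sentence is an arbitrary assumption:
+
+```python
+import sentencepiece as spm
+
+from icefall.byte_utils import byte_encode
+from icefall.utils import tokenize_by_CJK_char
+
+sp = spm.SentencePieceProcessor()
+sp.load("data/lang_bbpe_500/bbpe.model")  # produced during training
+
+text = "甚至 出现 交易 几乎 停滞 的 情况"  # arbitrary sample sentence
+# Split into single CJK characters, then map them to printable byte symbols,
+# so that the BPE model never emits out-of-vocabulary pieces.
+encoded = byte_encode(tokenize_by_CJK_char(text))
+print(sp.encode(encoded, out_type=str))
+```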
+
+
#### Pruned transducer stateless 3
See
@@ -15,6 +66,8 @@ It uses pruned RNN-T.
|------------------------|------|------|---------------------------------------|
| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 |
| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 |
+| modified beam search + RNNLM shallow fusion | 4.73 | 4.53 | --epoch 29 --avg 5 --max-duration 600 |
+| modified beam search + LODR | 4.57 | 4.37 | --epoch 29 --avg 5 --max-duration 600 |
| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 |
Training command is:
@@ -73,6 +126,78 @@ for epoch in 29; do
done
```
+We provide the option of shallow fusion with an RNN language model. The pre-trained language model is
+available at <https://huggingface.co/marcoyang/icefall-aishell-rnn-lm>. To decode with the language model,
+please use the following command:
+
+```bash
+# download pre-trained model
+git lfs install
+git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
+
+aishell_exp=icefall-aishell-pruned-transducer-stateless3-2022-06-20/
+
+pushd ${aishell_exp}/exp
+ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt epoch-99.pt
+popd
+
+# download RNN LM
+git lfs install
+git clone https://huggingface.co/marcoyang/icefall-aishell-rnn-lm
+rnnlm_dir=icefall-aishell-rnn-lm
+
+# RNNLM shallow fusion
+for lm_scale in $(seq 0.26 0.02 0.34); do
+ python ./pruned_transducer_stateless3/decode.py \
+ --epoch 99 \
+ --avg 1 \
+ --lang-dir ${aishell_exp}/data/lang_char \
+ --exp-dir ${aishell_exp}/exp \
+ --use-averaged-model False \
+ --decoding-method modified_beam_search_lm_shallow_fusion \
+ --use-shallow-fusion 1 \
+ --lm-type rnn \
+ --lm-exp-dir ${rnnlm_dir}/exp \
+ --lm-epoch 99 \
+ --lm-scale $lm_scale \
+ --lm-avg 1 \
+ --rnn-lm-embedding-dim 2048 \
+ --rnn-lm-hidden-dim 2048 \
+ --rnn-lm-num-layers 2 \
+ --lm-vocab-size 4336
+done
+
+# RNNLM Low-order density ratio (LODR) with a 2-gram
+
+cp ${rnnlm_dir}/2gram.fst.txt ${aishell_exp}/data/lang_char/2gram.fst.txt
+
+for lm_scale in 0.48; do
+ for LODR_scale in -0.28; do
+ python ./pruned_transducer_stateless3/decode.py \
+ --epoch 99 \
+ --avg 1 \
+ --lang-dir ${aishell_exp}/data/lang_char \
+ --exp-dir ${aishell_exp}/exp \
+ --use-averaged-model False \
+ --decoding-method modified_beam_search_LODR \
+ --use-shallow-fusion 1 \
+ --lm-type rnn \
+ --lm-exp-dir ${rnnlm_dir}/exp \
+ --lm-epoch 99 \
+ --lm-scale $lm_scale \
+ --lm-avg 1 \
+ --rnn-lm-embedding-dim 2048 \
+ --rnn-lm-hidden-dim 2048 \
+ --rnn-lm-num-layers 2 \
+ --lm-vocab-size 4336 \
+ --tokens-ngram 2 \
+ --backoff-id 4336 \
+ --ngram-lm-scale $LODR_scale
+ done
+done
+
+```
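+
+Schematically, both options only change how each candidate hypothesis is
+scored during modified beam search. The sketch below is illustrative (the
+function names are made up and it does not reproduce `beam_search.py`):
+shallow fusion adds the scaled RNNLM log-probability, while LODR additionally
+adds a low-order 2-gram log-probability with a negative scale (hence
+`--ngram-lm-scale -0.28` above) to discount the internal LM of the transducer:
+
+```python
+def shallow_fusion_score(asr_logprob, rnnlm_logprob, lm_scale):
+    # Schematic: modified_beam_search_lm_shallow_fusion ranks hypotheses by
+    # the transducer score plus the scaled RNNLM score.
+    return asr_logprob + lm_scale * rnnlm_logprob
+
+
+def lodr_score(asr_logprob, rnnlm_logprob, bigram_logprob, lm_scale, ngram_lm_scale):
+    # Schematic: modified_beam_search_LODR additionally discounts the internal
+    # LM with a 2-gram estimate; note that ngram_lm_scale is negative above.
+    return asr_logprob + lm_scale * rnnlm_logprob + ngram_lm_scale * bigram_logprob
+```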
+
Pretrained models, training logs, decoding logs, and decoding results
are available at <https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20>
diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py
index cb7205e51..ab1cbbae4 100644
--- a/egs/aishell/ASR/conformer_ctc/conformer.py
+++ b/egs/aishell/ASR/conformer_ctc/conformer.py
@@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
- self.self_attn = RelPositionMultiheadAttention(
- d_model, nhead, dropout=0.0
- )
+ self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
- self.norm_ff_macaron = nn.LayerNorm(
- d_model
- ) # for the macaron style FNN module
+ self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
- self.norm_final = nn.LayerNorm(
- d_model
- ) # for the final output of the block
+ self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.dropout = nn.Dropout(dropout)
@@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
- src = residual + self.ff_scale * self.dropout(
- self.feed_forward_macaron(src)
- )
+ src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
if not self.normalize_before:
src = self.norm_ff_macaron(src)
@@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
"""
- def __init__(
- self, d_model: int, dropout_rate: float, max_len: int = 5000
- ) -> None:
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
@@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
- if self.pe.dtype != x.dtype or str(self.pe.device) != str(
- x.device
- ):
+ if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vector and `j` means the
@@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
- q, k, v = nn.functional.linear(
- query, in_proj_weight, in_proj_bias
- ).chunk(3, dim=-1)
+ q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+ 3, dim=-1
+ )
elif torch.equal(key, value):
# encoder-decoder attention
@@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
- raise RuntimeError(
- "The size of the 2D attn_mask is not correct."
- )
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
- raise RuntimeError(
- "The size of the 3D attn_mask is not correct."
- )
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
- "attn_mask's dimension {} is not supported".format(
- attn_mask.dim()
- )
+ "attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
- if (
- key_padding_mask is not None
- and key_padding_mask.dtype == torch.uint8
- ):
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
@@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
- matrix_ac = torch.matmul(
- q_with_bias_u, k
- ) # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
@@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
- attn_output_weights = attn_output_weights.view(
- bsz * num_heads, tgt_len, -1
- )
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
@@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
- attn_output.transpose(0, 1)
- .contiguous()
- .view(tgt_len, bsz, embed_dim)
- )
- attn_output = nn.functional.linear(
- attn_output, out_proj_weight, out_proj_bias
+ attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
)
+ attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
@@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
"""
- def __init__(
- self, channels: int, kernel_size: int, bias: bool = True
- ) -> None:
+ def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index 751b7d5b5..74a7b5933 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -401,9 +401,7 @@ def decode_dataset(
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -431,9 +429,7 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
- results_char.append(
- (res[0], list("".join(res[1])), list("".join(res[2])))
- )
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -441,9 +437,7 @@ def save_results(
test_set_wers[key] = wer
if enable_log:
- logging.info(
- "Wrote detailed error stats to {}".format(errs_filename)
- )
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -562,9 +556,7 @@ def main():
eos_id=eos_id,
)
- save_results(
- params=params, test_set_name=test_set, results_dict=results_dict
- )
+ save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py
index 42b8c29e7..1df3cfdc2 100644
--- a/egs/aishell/ASR/conformer_ctc/export.py
+++ b/egs/aishell/ASR/conformer_ctc/export.py
@@ -157,9 +157,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py
index 27776bc24..66d583396 100755
--- a/egs/aishell/ASR/conformer_ctc/pretrained.py
+++ b/egs/aishell/ASR/conformer_ctc/pretrained.py
@@ -210,10 +210,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -274,9 +273,7 @@ def main():
logging.info("Decoding started")
features = fbank(waves)
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
# Note: We don't use key padding mask for attention during decoding
with torch.no_grad():
@@ -371,9 +368,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py
index 542fb0364..8e0f73d05 100644
--- a/egs/aishell/ASR/conformer_ctc/subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/subsampling.py
@@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
- nn.Conv2d(
- in_channels=1, out_channels=odim, kernel_size=3, stride=2
- ),
+ nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
- nn.Conv2d(
- in_channels=odim, out_channels=odim, kernel_size=3, stride=2
- ),
+ nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
)
)
layers.append(
- torch.nn.MaxPool2d(
- kernel_size=2, stride=2, padding=0, ceil_mode=True
- )
+ torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
- self.out = nn.Linear(
- block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
- )
+ self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py
index e3361d0c9..81fa234dd 100755
--- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/test_subsampling.py
@@ -16,9 +16,8 @@
# limitations under the License.
-from subsampling import Conv2dSubsampling
-from subsampling import VggSubsampling
import torch
+from subsampling import Conv2dSubsampling, VggSubsampling
def test_conv2d_subsampling():
diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py
index a228cc1fe..c2cbe6e3b 100755
--- a/egs/aishell/ASR/conformer_ctc/train.py
+++ b/egs/aishell/ASR/conformer_ctc/train.py
@@ -382,9 +382,7 @@ def compute_loss(
#
# See https://github.com/k2-fsa/icefall/issues/97
# for more details
- unsorted_token_ids = graph_compiler.texts_to_ids(
- supervisions["text"]
- )
+ unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
att_loss = mmodel.decoder_forward(
encoder_memory,
memory_mask,
@@ -520,9 +518,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@@ -630,9 +626,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
- tb_writer.add_scalar(
- "train/learning_rate", cur_lr, params.batch_idx_train
- )
+ tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py
index f93914aaa..a3e50e385 100644
--- a/egs/aishell/ASR/conformer_ctc/transformer.py
+++ b/egs/aishell/ASR/conformer_ctc/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
norm=decoder_norm,
)
- self.decoder_output_layer = torch.nn.Linear(
- d_model, self.decoder_num_class
- )
+ self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
self.decoder_criterion = LabelSmoothingLoss()
else:
@@ -183,9 +181,7 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
- encoder_memory, memory_key_padding_mask = self.run_encoder(
- x, supervision
- )
+ encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
@@ -266,23 +262,17 @@ class Transformer(nn.Module):
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
- ys_in_pad = pad_sequence(
- ys_in, batch_first=True, padding_value=float(eos_id)
- )
+ ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
- ys_out_pad = pad_sequence(
- ys_out, batch_first=True, padding_value=float(-1)
- )
+ ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device)
- tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
- device
- )
+ tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@@ -343,23 +333,17 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
- ys_in_pad = pad_sequence(
- ys_in, batch_first=True, padding_value=float(eos_id)
- )
+ ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
- ys_out_pad = pad_sequence(
- ys_out, batch_first=True, padding_value=float(-1)
- )
+ ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
- tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
- device
- )
+ tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
elif activation == "gelu":
return nn.functional.gelu
- raise RuntimeError(
- "activation should be relu/gelu, not {}".format(activation)
- )
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class PositionalEncoding(nn.Module):
@@ -836,9 +818,7 @@ def encoder_padding_mask(
1,
).to(torch.int32)
- lengths = [
- 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
- ]
+ lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item()
@@ -859,9 +839,7 @@ def encoder_padding_mask(
return mask
-def decoder_padding_mask(
- ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
"""Generate a length mask for input.
The masked position are filled with True,
diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py
index cb7205e51..ab1cbbae4 100644
--- a/egs/aishell/ASR/conformer_mmi/conformer.py
+++ b/egs/aishell/ASR/conformer_mmi/conformer.py
@@ -157,9 +157,7 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
- self.self_attn = RelPositionMultiheadAttention(
- d_model, nhead, dropout=0.0
- )
+ self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@@ -177,18 +175,14 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
- self.norm_ff_macaron = nn.LayerNorm(
- d_model
- ) # for the macaron style FNN module
+ self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
- self.norm_final = nn.LayerNorm(
- d_model
- ) # for the final output of the block
+ self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.dropout = nn.Dropout(dropout)
@@ -222,9 +216,7 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
- src = residual + self.ff_scale * self.dropout(
- self.feed_forward_macaron(src)
- )
+ src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
if not self.normalize_before:
src = self.norm_ff_macaron(src)
@@ -343,9 +335,7 @@ class RelPositionalEncoding(torch.nn.Module):
"""
- def __init__(
- self, d_model: int, dropout_rate: float, max_len: int = 5000
- ) -> None:
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
@@ -361,9 +351,7 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
- if self.pe.dtype != x.dtype or str(self.pe.device) != str(
- x.device
- ):
+ if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vector and `j` means the
@@ -633,9 +621,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
- q, k, v = nn.functional.linear(
- query, in_proj_weight, in_proj_bias
- ).chunk(3, dim=-1)
+ q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+ 3, dim=-1
+ )
elif torch.equal(key, value):
# encoder-decoder attention
@@ -703,31 +691,22 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
- raise RuntimeError(
- "The size of the 2D attn_mask is not correct."
- )
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
- raise RuntimeError(
- "The size of the 3D attn_mask is not correct."
- )
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
- "attn_mask's dimension {} is not supported".format(
- attn_mask.dim()
- )
+ "attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
- if (
- key_padding_mask is not None
- and key_padding_mask.dtype == torch.uint8
- ):
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
@@ -766,9 +745,7 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
- matrix_ac = torch.matmul(
- q_with_bias_u, k
- ) # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
@@ -780,9 +757,7 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
- attn_output_weights = attn_output_weights.view(
- bsz * num_heads, tgt_len, -1
- )
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
@@ -816,13 +791,9 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
- attn_output.transpose(0, 1)
- .contiguous()
- .view(tgt_len, bsz, embed_dim)
- )
- attn_output = nn.functional.linear(
- attn_output, out_proj_weight, out_proj_bias
+ attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
)
+ attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
@@ -845,9 +816,7 @@ class ConvolutionModule(nn.Module):
"""
- def __init__(
- self, channels: int, kernel_size: int, bias: bool = True
- ) -> None:
+ def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py
index 4db367e36..20a855e7f 100755
--- a/egs/aishell/ASR/conformer_mmi/decode.py
+++ b/egs/aishell/ASR/conformer_mmi/decode.py
@@ -413,9 +413,7 @@ def decode_dataset(
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -443,9 +441,7 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
- results_char.append(
- (res[0], list("".join(res[1])), list("".join(res[2])))
- )
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@@ -453,9 +449,7 @@ def save_results(
test_set_wers[key] = wer
if enable_log:
- logging.info(
- "Wrote detailed error stats to {}".format(errs_filename)
- )
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
@@ -550,9 +544,7 @@ def main():
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
- torch.save(
- {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
- )
+ torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
return
model.to(device)
@@ -581,9 +573,7 @@ def main():
eos_id=eos_id,
)
- save_results(
- params=params, test_set_name=test_set, results_dict=results_dict
- )
+ save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
diff --git a/egs/aishell/ASR/conformer_mmi/subsampling.py b/egs/aishell/ASR/conformer_mmi/subsampling.py
index 720ed6c22..398837a46 100644
--- a/egs/aishell/ASR/conformer_mmi/subsampling.py
+++ b/egs/aishell/ASR/conformer_mmi/subsampling.py
@@ -42,13 +42,9 @@ class Conv2dSubsampling(nn.Module):
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
- nn.Conv2d(
- in_channels=1, out_channels=odim, kernel_size=3, stride=2
- ),
+ nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
- nn.Conv2d(
- in_channels=odim, out_channels=odim, kernel_size=3, stride=2
- ),
+ nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
@@ -132,17 +128,13 @@ class VggSubsampling(nn.Module):
)
)
layers.append(
- torch.nn.MaxPool2d(
- kernel_size=2, stride=2, padding=0, ceil_mode=True
- )
+ torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
- self.out = nn.Linear(
- block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
- )
+ self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py
index 685831d09..09cd6e60c 100755
--- a/egs/aishell/ASR/conformer_mmi/train.py
+++ b/egs/aishell/ASR/conformer_mmi/train.py
@@ -511,9 +511,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@@ -625,9 +623,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
- tb_writer.add_scalar(
- "train/learning_rate", cur_lr, params.batch_idx_train
- )
+ tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
diff --git a/egs/aishell/ASR/conformer_mmi/transformer.py b/egs/aishell/ASR/conformer_mmi/transformer.py
index f93914aaa..a3e50e385 100644
--- a/egs/aishell/ASR/conformer_mmi/transformer.py
+++ b/egs/aishell/ASR/conformer_mmi/transformer.py
@@ -149,9 +149,7 @@ class Transformer(nn.Module):
norm=decoder_norm,
)
- self.decoder_output_layer = torch.nn.Linear(
- d_model, self.decoder_num_class
- )
+ self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class)
self.decoder_criterion = LabelSmoothingLoss()
else:
@@ -183,9 +181,7 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
- encoder_memory, memory_key_padding_mask = self.run_encoder(
- x, supervision
- )
+ encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
@@ -266,23 +262,17 @@ class Transformer(nn.Module):
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
- ys_in_pad = pad_sequence(
- ys_in, batch_first=True, padding_value=float(eos_id)
- )
+ ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
- ys_out_pad = pad_sequence(
- ys_out, batch_first=True, padding_value=float(-1)
- )
+ ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device)
- tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
- device
- )
+ tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@@ -343,23 +333,17 @@ class Transformer(nn.Module):
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
- ys_in_pad = pad_sequence(
- ys_in, batch_first=True, padding_value=float(eos_id)
- )
+ ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
- ys_out_pad = pad_sequence(
- ys_out, batch_first=True, padding_value=float(-1)
- )
+ ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1))
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
- tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
- device
- )
+ tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
@@ -632,9 +616,7 @@ def _get_activation_fn(activation: str):
elif activation == "gelu":
return nn.functional.gelu
- raise RuntimeError(
- "activation should be relu/gelu, not {}".format(activation)
- )
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class PositionalEncoding(nn.Module):
@@ -836,9 +818,7 @@ def encoder_padding_mask(
1,
).to(torch.int32)
- lengths = [
- 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
- ]
+ lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)]
for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item()
@@ -859,9 +839,7 @@ def encoder_padding_mask(
return mask
-def decoder_padding_mask(
- ys_pad: torch.Tensor, ignore_id: int = -1
-) -> torch.Tensor:
+def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
"""Generate a length mask for input.
The masked position are filled with True,
diff --git a/egs/aishell/ASR/local/compile_lg.py b/egs/aishell/ASR/local/compile_lg.py
new file mode 120000
index 000000000..462d6d3fb
--- /dev/null
+++ b/egs/aishell/ASR/local/compile_lg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
index 42700a972..037971927 100755
--- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
+++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -87,9 +87,7 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
- cut_set
- + cut_set.perturb_speed(0.9)
- + cut_set.perturb_speed(1.1)
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@@ -116,9 +114,7 @@ def get_args():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index deab6c809..115ca1031 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -83,9 +83,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
- cut_set
- + cut_set.perturb_speed(0.9)
- + cut_set.perturb_speed(1.1)
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@@ -111,9 +109,7 @@ def get_args():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py
index d9e47d17a..8cc0502c2 100755
--- a/egs/aishell/ASR/local/prepare_char.py
+++ b/egs/aishell/ASR/local/prepare_char.py
@@ -33,6 +33,7 @@ and generates the following files in the directory `lang_dir`:
- tokens.txt
"""
+import argparse
import re
from pathlib import Path
from typing import Dict, List
@@ -86,9 +87,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
- pieces = [
-            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
-        ]
+        pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@@ -142,9 +141,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
-def generate_lexicon(
- token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
@@ -193,8 +190,22 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
return tokens
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ It should contain the bpe.model and words.txt
+ """,
+ )
+
+ return parser.parse_args()
+
+
def main():
- lang_dir = Path("data/lang_char")
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
text_file = lang_dir / "text"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
diff --git a/egs/aishell/ASR/local/prepare_char_lm_training_data.py b/egs/aishell/ASR/local/prepare_char_lm_training_data.py
new file mode 100755
index 000000000..e7995680b
--- /dev/null
+++ b/egs/aishell/ASR/local/prepare_char_lm_training_data.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
+# Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes a `tokens.txt` and a text file such as
+./download/lm/aishell-transcript.txt
+and outputs the LM training data to a supplied directory such
+as data/lm_training_char. The format is as follows:
+It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
+representation of a dict with the same format as the librispeech recipe.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-char",
+ type=str,
+ help="""Lang dir of asr model, e.g. data/lang_char""",
+ )
+ parser.add_argument(
+ "--lm-data",
+ type=str,
+ help="""Input LM training data as text, e.g.
+ download/lm/aishell-train-word.txt""",
+ )
+ parser.add_argument(
+ "--lm-archive",
+ type=str,
+ help="""Path to output archive, e.g. data/lm_training_char/lm_data.pt;
+ look at the source of this script to see the format.""",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ if Path(args.lm_archive).exists():
+ logging.warning(f"{args.lm_archive} exists - skipping")
+ return
+
+ # make token_dict from tokens.txt in order to map characters to tokens.
+ token_dict = {}
+ token_file = args.lang_char + "/tokens.txt"
+
+ with open(token_file, "r") as f:
+ for line in f.readlines():
+ line_list = line.split()
+ token_dict[line_list[0]] = int(line_list[1])
+
+ # word2index is a dictionary from words to integer ids. No need to reserve
+ # space for epsilon, etc.; the words are just used as a convenient way to
+ # compress the sequences of tokens.
+ word2index = dict()
+
+ word2token = [] # Will be a list-of-list-of-int, representing tokens.
+ sentences = [] # Will be a list-of-list-of-int, representing word-ids.
+
+ if "aishell-lm" in args.lm_data:
+ num_lines_in_total = 120098.0
+ step = 50000
+ elif "valid" in args.lm_data:
+ num_lines_in_total = 14326.0
+ step = 3000
+ elif "test" in args.lm_data:
+ num_lines_in_total = 7176.0
+ step = 3000
+ else:
+ num_lines_in_total = None
+ step = None
+
+ processed = 0
+
+ with open(args.lm_data) as f:
+ while True:
+ line = f.readline()
+ if line == "":
+ break
+
+ if step and processed % step == 0:
+ logging.info(
+ f"Processed number of lines: {processed} "
+ f"({processed / num_lines_in_total * 100: .3f}%)"
+ )
+ processed += 1
+
+ line_words = line.split()
+ for w in line_words:
+ if w not in word2index:
+ w_token = []
+ for t in w:
+ if t in token_dict:
+ w_token.append(token_dict[t])
+ else:
+                        w_token.append(token_dict["<unk>"])
+ word2index[w] = len(word2token)
+ word2token.append(w_token)
+ sentences.append([word2index[w] for w in line_words])
+
+ logging.info("Constructing ragged tensors")
+ words = k2.ragged.RaggedTensor(word2token)
+ sentences = k2.ragged.RaggedTensor(sentences)
+
+ output = dict(words=words, sentences=sentences)
+
+ num_sentences = sentences.dim0
+ logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
+ sentence_lengths = [0] * num_sentences
+ for i in range(num_sentences):
+ if step and i % step == 0:
+ logging.info(
+ f"Processed number of lines: {i} ({i / num_sentences * 100: .3f}%)"
+ )
+
+ word_ids = sentences[i]
+
+ # NOTE: If word_ids is a tensor with only 1 entry,
+ # token_ids is a torch.Tensor
+ token_ids = words[word_ids]
+ if isinstance(token_ids, k2.RaggedTensor):
+ token_ids = token_ids.values
+
+ # token_ids is a 1-D tensor containing the BPE tokens
+ # of the current sentence
+
+ sentence_lengths[i] = token_ids.numel()
+
+ output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32)
+
+ torch.save(output, args.lm_archive)
+ logging.info(f"Saved to {args.lm_archive}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/aishell/ASR/local/prepare_lang.py
+++ b/egs/aishell/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
- parser.add_argument(
- "--lang-dir", type=str, help="The lang dir, data/lang_phone"
- )
+ parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
return parser.parse_args()
diff --git a/egs/aishell/ASR/local/prepare_lang_bbpe.py b/egs/aishell/ASR/local/prepare_lang_bbpe.py
new file mode 100755
index 000000000..ddd90622e
--- /dev/null
+++ b/egs/aishell/ASR/local/prepare_lang_bbpe.py
@@ -0,0 +1,267 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+
+This script takes as input `lang_dir`, which should contain::
+
+ - lang_dir/bbpe.model,
+ - lang_dir/words.txt
+
+and generates the following files in the directory `lang_dir`:
+
+ - lexicon.txt
+ - lexicon_disambig.txt
+ - L.pt
+ - L_disambig.pt
+ - tokens.txt
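+
+A typical invocation (see also ../prepare.sh, stage 7):
+
+ ./local/prepare_lang_bbpe.py --lang-dir data/lang_bbpe_500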
+"""
+
+import argparse
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+from prepare_lang import (
+ Lexicon,
+ add_disambig_symbols,
+ add_self_loops,
+ write_lexicon,
+ write_mapping,
+)
+
+from icefall.byte_utils import byte_encode
+from icefall.utils import str2bool, tokenize_by_CJK_char
+
+
+def lexicon_to_fst_no_sil(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format).
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ loop_state = 0 # words enter and leave from here
+ next_state = 1 # the next un-allocated state, will be incremented as we go
+
+ arcs = []
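+ # Each arc is [src_state, dst_state, token_id, word_id, score]; the arcs
+ # are later converted to the text form accepted by k2.Fsa.from_str().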
+
+ # The blank symbol is defined in local/train_bpe_model.py
+ assert token2id[""] == 0
+ assert word2id[""] == 0
+
+ eps = 0
+
+ for word, pieces in lexicon:
+ assert len(pieces) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ pieces = [token2id[i] for i in pieces]
+
+ for i in range(len(pieces) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, pieces[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last piece of this word
+ i = len(pieces) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, pieces[i], w, 0])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+def generate_lexicon(
+ model_file: str, words: List[str], oov: str
+) -> Tuple[Lexicon, Dict[str, int]]:
+ """Generate a lexicon from a BPE model.
+
+ Args:
+ model_file:
+ Path to a sentencepiece model.
+ words:
+ A list of strings representing words.
+ oov:
+ The out of vocabulary word in lexicon.
+ Returns:
+ Return a tuple with two elements:
+ - A lexicon, i.e. a list of (word, word_pieces) tuples.
+ - A dict mapping token symbols to their IDs.
+ """
+ sp = spm.SentencePieceProcessor()
+ sp.load(str(model_file))
+
+ # Convert word to word piece IDs instead of word piece strings
+ # to avoid OOV tokens.
+ encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words]
+ words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int)
+
+ # Now convert word piece IDs back to word piece strings.
+ words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]
+
+ lexicon = []
+ for word, pieces in zip(words, words_pieces):
+ lexicon.append((word, pieces))
+
+ lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
+
+ token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
+
+ return lexicon, token2id
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ It should contain the bbpe.model and words.txt
+ """,
+ )
+
+ parser.add_argument(
+ "--oov",
+ type=str,
+ default="",
+ help="The out of vocabulary word in lexicon.",
+ )
+
+ parser.add_argument(
+ "--debug",
+ type=str2bool,
+ default=False,
+ help="""True for debugging, which will generate
+ a visualization of the lexicon FST.
+
+ Caution: If your lexicon contains hundreds of thousands
+ of lines, please set it to False!
+
+ See "test/test_bpe_lexicon.py" for usage.
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+ model_file = lang_dir / "bbpe.model"
+
+ word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+ words = word_sym_table.symbols
+
+ excluded = ["", "!SIL", "", args.oov, "#0", "", ""]
+
+ for w in excluded:
+ if w in words:
+ words.remove(w)
+
+ lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
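+ # Register the disambiguation symbols #0, #1, ..., #max_disambig with IDs
+ # following the last BPE token ID.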
+ next_token_id = max(token_sym_table.values()) + 1
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in token_sym_table
+ token_sym_table[disambig] = next_token_id
+ next_token_id += 1
+
+ word_sym_table.add("#0")
+ word_sym_table.add("")
+ word_sym_table.add("")
+
+ write_mapping(lang_dir / "tokens.txt", token_sym_table)
+
+ write_lexicon(lang_dir / "lexicon.txt", lexicon)
+ write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst_no_sil(
+ lexicon,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ )
+
+ L_disambig = lexicon_to_fst_no_sil(
+ lexicon_disambig,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), lang_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
+
+ if args.debug:
+ labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
+ aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+ L.labels_sym = labels_sym
+ L.aux_labels_sym = aux_labels_sym
+ L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
+
+ L_disambig.labels_sym = labels_sym
+ L_disambig.aux_labels_sym = aux_labels_sym
+ L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/local/sort_lm_training_data.py b/egs/aishell/ASR/local/sort_lm_training_data.py
new file mode 120000
index 000000000..1d6ccbe33
--- /dev/null
+++ b/egs/aishell/ASR/local/sort_lm_training_data.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/sort_lm_training_data.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/local/test_prepare_lang.py b/egs/aishell/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/aishell/ASR/local/test_prepare_lang.py
+++ b/egs/aishell/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
- fsa_disambig = lexicon_to_fst(
- lexicon_disambig, phone2id=phone2id, word2id=word2id
- )
+ fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aishell/ASR/local/train_bbpe_model.py b/egs/aishell/ASR/local/train_bbpe_model.py
new file mode 100755
index 000000000..d231d5d77
--- /dev/null
+++ b/egs/aishell/ASR/local/train_bbpe_model.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# You can install sentencepiece via:
+#
+# pip install sentencepiece
+#
+# Due to an issue reported in
+# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
+#
+# Please install a version >=0.1.96
+
+import argparse
+import re
+import shutil
+import tempfile
+from pathlib import Path
+
+import sentencepiece as spm
+from icefall import byte_encode, tokenize_by_CJK_char
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ The generated bbpe.model is saved to this directory.
+ """,
+ )
+
+ parser.add_argument(
+ "--transcript",
+ type=str,
+ help="Training transcript.",
+ )
+
+ parser.add_argument(
+ "--vocab-size",
+ type=int,
+ help="Vocabulary size for BPE training",
+ )
+
+ return parser.parse_args()
+
+
+def _convert_to_bchar(in_path: str, out_path: str):
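+ # Tokenize each transcript line by CJK characters and byte-encode it, so
+ # that the sentencepiece model is trained on byte-level symbols.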
+ with open(out_path, "w") as f:
+ for line in open(in_path, "r").readlines():
+ f.write(byte_encode(tokenize_by_CJK_char(line)) + "\n")
+
+
+def main():
+ args = get_args()
+ vocab_size = args.vocab_size
+ lang_dir = Path(args.lang_dir)
+
+ model_type = "unigram"
+
+ model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+ character_coverage = 1.0
+ input_sentence_size = 100000000
+
+ user_defined_symbols = ["<blk>", "<sos/eos>"]
+ unk_id = len(user_defined_symbols)
+ # Note: unk_id is fixed to 2.
+ # If you change it, you should also change other
+ # places that are using it.
+
+ temp = tempfile.NamedTemporaryFile()
+ train_text = temp.name
+
+ _convert_to_bchar(args.transcript, train_text)
+
+ model_file = Path(model_prefix + ".model")
+ if not model_file.is_file():
+ spm.SentencePieceTrainer.train(
+ input=train_text,
+ vocab_size=vocab_size,
+ model_type=model_type,
+ model_prefix=model_prefix,
+ input_sentence_size=input_sentence_size,
+ character_coverage=character_coverage,
+ user_defined_symbols=user_defined_symbols,
+ unk_id=unk_id,
+ bos_id=-1,
+ eos_id=-1,
+ )
+ else:
+ print(f"{model_file} exists - skipping")
+ return
+
+ shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index eaeecfc4a..b763d72c1 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -1,10 +1,13 @@
#!/usr/bin/env bash
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
set -eou pipefail
nj=15
stage=-1
-stop_stage=10
+stop_stage=11
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
@@ -32,6 +35,15 @@ dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
+# vocab size for sentence piece models.
+# It will generate data/lang_bbpe_xxx,
+# data/lang_bbpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+ # 2000
+ # 1000
+ 500
+)
+
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
@@ -44,20 +56,6 @@ log() {
log "dl_dir: $dl_dir"
-if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
- log "stage -1: Download LM"
- # We assume that you have installed the git-lfs, if not, you could install it
- # using: `sudo apt-get install git-lfs && git-lfs install`
- git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
-
- if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
- git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
- pushd $dl_dir/lm
- git lfs pull --include "3-gram.unpruned.arpa"
- popd
- fi
-fi
-
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: Download data"
@@ -131,7 +129,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
lang_phone_dir=data/lang_phone
-lang_char_dir=data/lang_char
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
mkdir -p $lang_phone_dir
@@ -180,39 +177,194 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi
fi
+lang_char_dir=data/lang_char
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare char based lang"
mkdir -p $lang_char_dir
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
- cp $lang_phone_dir/words.txt $lang_char_dir
+
+ # The transcripts in training set, generated in stage 5
+ cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
- cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > $lang_char_dir/text
+ cut -d " " -f 2- > $lang_char_dir/text
+
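+ # words.txt starts with <eps>, !SIL, <SPOKEN_NOISE> and <UNK>, followed by
+ # the words in the training text, and ends with #0, <s> and </s>.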
+ (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
+ > $lang_char_dir/words.txt
+
+ cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
+ | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
+
+ num_lines=$(< $lang_char_dir/words.txt wc -l)
+ (echo "#0 $num_lines"; echo " $(($num_lines + 1))"; echo " $(($num_lines + 2))";) \
+ >> $lang_char_dir/words.txt
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
- ./local/prepare_char.py
+ ./local/prepare_char.py --lang-dir $lang_char_dir
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
- log "Stage 7: Prepare G"
- # We assume you have install kaldilm, if not, please install
- # it using: pip install kaldilm
+ log "Stage 7: Prepare Byte BPE based lang"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bbpe_${vocab_size}
+ mkdir -p $lang_dir
+
+ cp $lang_char_dir/words.txt $lang_dir
+ cp $lang_char_dir/text $lang_dir
+
+ if [ ! -f $lang_dir/bbpe.model ]; then
+ ./local/train_bbpe_model.py \
+ --lang-dir $lang_dir \
+ --vocab-size $vocab_size \
+ --transcript $lang_dir/text
+ fi
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang_bbpe.py --lang-dir $lang_dir
+ fi
+ done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ log "Stage 8: Prepare G"
mkdir -p data/lm
- if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+
+ # Train LM on transcripts
+ if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
+ python3 ./shared/make_kn_lm.py \
+ -ngram-order 3 \
+ -text $lang_char_dir/transcript_words.txt \
+ -lm data/lm/3-gram.unpruned.arpa
+ fi
+
+ # We assume you have installed kaldilm; if not, please install
+ # it using: pip install kaldilm
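+ # The phone-based and char-based lexicons use different words.txt files,
+ # so a separate G is compiled for each of them.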
+ if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="$lang_phone_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
- $dl_dir/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
+ data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_phone.fst.txt
+
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_char_dir/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=3 \
+ data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram_char.fst.txt
fi
fi
-if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
- log "Stage 8: Compile HLG"
- ./local/compile_hlg.py --lang-dir $lang_phone_dir
- ./local/compile_hlg.py --lang-dir $lang_char_dir
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+ log "Stage 9: Compile LG & HLG"
+ ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
+ ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bbpe_${vocab_size}
+ ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
+ done
+
+ ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
+ ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bbpe_${vocab_size}
+ ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
+ done
+fi
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+ log "Stage 10: Generate LM training data"
+
+ log "Processing char based data"
+ out_dir=data/lm_training_char
+ mkdir -p $out_dir $dl_dir/lm
+
+ if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
+ cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
+ fi
+
+ # training words
+ ./local/prepare_char_lm_training_data.py \
+ --lang-char data/lang_char \
+ --lm-data $dl_dir/lm/aishell-train-word.txt \
+ --lm-archive $out_dir/lm_data.pt
+
+ # valid words
+ if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
+ aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+ aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
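+ # Collect the utterance IDs of the dev set, keep only the matching
+ # transcript lines and drop the leading utterance-ID column.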
+ find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
+ awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
+ cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
+ fi
+
+ ./local/prepare_char_lm_training_data.py \
+ --lang-char data/lang_char \
+ --lm-data $dl_dir/lm/aishell-valid-word.txt \
+ --lm-archive $out_dir/lm_data_valid.pt
+
+ # test words
+ if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
+ aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+ aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
+ find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
+ awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
+ cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
+ fi
+
+ ./local/prepare_char_lm_training_data.py \
+ --lang-char data/lang_char \
+ --lm-data $dl_dir/lm/aishell-test-word.txt \
+ --lm-archive $out_dir/lm_data_test.pt
+fi
+
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+ log "Stage 11: Sort LM training data"
+ # Sort LM training data by sentence length in descending order
+ # for ease of training.
+ #
+ # Sentence length equals the number of tokens
+ # in a sentence.
+
+ out_dir=data/lm_training_char
+ mkdir -p $out_dir
+ ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
+
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data.pt \
+ --out-lm-data $out_dir/sorted_lm_data.pt \
+ --out-statistics $out_dir/statistics.txt
+
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data_valid.pt \
+ --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+ --out-statistics $out_dir/statistics-valid.txt
+
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data_test.pt \
+ --out-lm-data $out_dir/sorted_lm_data-test.pt \
+ --out-statistics $out_dir/statistics-test.txt
+fi
+
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+ log "Stage 11: Train RNN LM model"
+ python ../../../icefall/rnn_lm/train.py \
+ --start-epoch 0 \
+ --world-size 1 \
+ --num-epochs 20 \
+ --use-fp16 0 \
+ --embedding-dim 512 \
+ --hidden-dim 512 \
+ --num-layers 2 \
+ --batch-size 400 \
+ --exp-dir rnnlm_char/exp \
+ --lm-data $out_dir/sorted_lm_data.pt \
+ --lm-data-valid $out_dir/sorted_lm_data-valid.pt \
+ --vocab-size 4336 \
+ --master-port 12345
fi
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
index a12934d55..fb6c7c481 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
@@ -76,11 +76,7 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model
-from icefall.checkpoint import (
- average_checkpoints,
- find_checkpoints,
- load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@@ -188,8 +184,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
- encoder_out, encoder_out_lens = model.encoder(
- x=feature, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@@ -263,10 +256,7 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -387,9 +377,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -400,24 +388,18 @@ def save_results(
):
test_set_wers = dict()
for key, results in results_dict.items():
- recog_path = (
- params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
- errs_filename = (
- params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
# we compute CER for aishell dataset.
results_char = []
for res in results:
- results_char.append(
- (res[0], list("".join(res[1])), list("".join(res[2])))
- )
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -427,10 +409,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
- errs_info = (
- params.res_dir
- / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
@@ -473,9 +452,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
- params.suffix += (
- f"-{params.decoding_method}-beam-size-{params.beam_size}"
- )
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -504,8 +481,7 @@ def main():
]
if len(filenames) == 0:
raise ValueError(
- f"No checkpoints found for"
- f" --iter {params.iter}, --avg {params.avg}"
+ f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
index feababdd2..2ce5cfe69 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
@@ -50,11 +50,7 @@ from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
-from icefall.checkpoint import (
- average_checkpoints,
- find_checkpoints,
- load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
@@ -120,8 +116,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
@@ -157,8 +152,7 @@ def main():
]
if len(filenames) == 0:
raise ValueError(
- f"No checkpoints found for"
- f" --iter {params.iter}, --avg {params.avg}"
+ f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
@@ -191,9 +185,7 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
- filename = (
- params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
- )
+ filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
@@ -201,17 +193,14 @@ def main():
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = (
- params.exp_dir
- / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+ params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
)
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
index 3c38e5db7..82c10f129 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -196,10 +195,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -256,13 +254,9 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
- encoder_out, encoder_out_lens = model.encoder(
- x=features, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
num_waves = encoder_out.size(0)
hyp_list = []
@@ -310,9 +304,7 @@ def main():
beam=params.beam_size,
)
else:
- raise ValueError(
- f"Unsupported decoding method: {params.method}"
- )
+ raise ValueError(f"Unsupported decoding method: {params.method}")
hyp_list.append(hyp)
hyps = []
@@ -329,9 +321,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
index 97d892754..d08908238 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py
@@ -49,7 +49,6 @@ import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from decoder import Decoder
@@ -75,9 +74,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
-LRSchedulerType = Union[
- torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_model_arguments(parser: argparse.ArgumentParser):
@@ -203,8 +200,7 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
- help="The initial learning rate. This value should not need "
- "to be changed.",
+ help="The initial learning rate. This value should not need to be changed.",
)
parser.add_argument(
@@ -227,8 +223,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -251,8 +246,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
- help="The scale to smooth the loss with am (output of encoder network)"
- "part.",
+ help="The scale to smooth the loss with am (output of encoder network) part.",
)
parser.add_argument(
@@ -561,11 +555,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
- device = (
- model.device
- if isinstance(model, DDP)
- else next(model.parameters()).device
- )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@@ -593,23 +583,16 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
- 0.0
- if warmup < 1.0
- else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
- )
- loss = (
- params.simple_loss_scale * simple_loss
- + pruned_loss_scale * pruned_loss
+ 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
+ loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- info["frames"] = (
- (feature_lens // params.subsampling_factor).sum().item()
- )
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -725,9 +708,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
- display_and_save_batch(
- batch, params=params, graph_compiler=graph_compiler
- )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
if params.print_diagnostics and batch_idx == 5:
@@ -891,7 +872,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2 ** 22
+ 2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -1029,9 +1010,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
- display_and_save_batch(
- batch, params=params, graph_compiler=graph_compiler
- )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
index d159e420b..27c64efaa 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
@@ -54,6 +54,40 @@ Usage:
--beam 4 \
--max-contexts 4 \
--max-states 8
+
+(5) modified beam search (with LM shallow fusion)
+./pruned_transducer_stateless3/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless3/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search_lm_shallow_fusion \
+ --beam-size 4 \
+ --lm-type rnn \
+ --lm-scale 0.3 \
+ --lm-exp-dir /path/to/LM \
+ --rnn-lm-epoch 99 \
+ --rnn-lm-avg 1 \
+ --rnn-lm-num-layers 3 \
+ --rnn-lm-tie-weights 1
+
+(6) modified beam search with LM shallow fusion + LODR
+./pruned_transducer_stateless3/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --max-duration 600 \
+ --exp-dir ./pruned_transducer_stateless3/exp \
+ --decoding-method modified_beam_search_LODR \
+ --beam-size 4 \
+ --lm-type rnn \
+ --lm-scale 0.48 \
+ --lm-exp-dir /path/to/LM \
+ --rnn-lm-epoch 99 \
+ --rnn-lm-avg 1 \
+ --rnn-lm-num-layers 3 \
+ --rnn-lm-tie-weights 1 \
+ --tokens-ngram 2 \
+ --ngram-lm-scale -0.28
"""
@@ -74,9 +108,12 @@ from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
)
from train import add_model_arguments, get_params, get_transducer_model
+from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@@ -202,8 +239,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -213,6 +249,60 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""Token Ngram used for rescoring.
+ Used only when the decoding method is
+ modified_beam_search_ngram_rescoring""",
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when the decoding method is modified_beam_search_LODR.
+ It specifies the scale for the n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="""ID of the backoff symbol.
+ Used only when the decoding method is
+ modified_beam_search_LODR""",
+ )
+
add_model_arguments(parser)
return parser
@@ -224,6 +314,9 @@ def decode_one_batch(
token_table: k2.SymbolTable,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
+ ngram_lm: Optional[NgramLm] = None,
+ ngram_lm_scale: float = 1.0,
+ LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@@ -263,9 +356,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
- encoder_out, encoder_out_lens = model.encoder(
- x=feature, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@@ -277,10 +368,7 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -293,6 +381,24 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ )
else:
hyp_tokens = []
batch_size = encoder_out.size(0)
@@ -340,6 +446,9 @@ def decode_dataset(
model: nn.Module,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
+ ngram_lm: Optional[NgramLm] = None,
+ ngram_lm_scale: float = 1.0,
+ LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@@ -385,6 +494,9 @@ def decode_dataset(
token_table=token_table,
decoding_graph=decoding_graph,
batch=batch,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ LM=LM,
)
for name, hyps in hyps_dict.items():
@@ -401,9 +513,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -414,24 +524,18 @@ def save_results(
):
test_set_wers = dict()
for key, results in results_dict.items():
- recog_path = (
- params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
- errs_filename = (
- params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
# we compute CER for aishell dataset.
results_char = []
for res in results:
- results_char.append(
- (res[0], list("".join(res[1])), list("".join(res[2])))
- )
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -441,10 +545,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
- errs_info = (
- params.res_dir
- / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
@@ -462,6 +563,7 @@ def save_results(
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
@@ -475,6 +577,8 @@ def main():
"beam_search",
"fast_beam_search",
"modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
)
params.res_dir = params.exp_dir / params.decoding_method
@@ -488,9 +592,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
- params.suffix += (
- f"-{params.decoding_method}-beam-size-{params.beam_size}"
- )
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -498,6 +600,19 @@ def main():
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
+ if "ngram" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ if params.use_shallow_fusion:
+ if params.lm_type == "rnn":
+ params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
+ elif params.lm_type == "transformer":
+ params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@@ -518,9 +633,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -551,9 +666,9 @@ def main():
)
else:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg + 1]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -607,6 +722,35 @@ def main():
else:
decoding_graph = None
+ # only load N-gram LM when needed
+ if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
+ lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"lm filename: {lm_filename}")
+ ngram_lm = NgramLm(
+ lm_filename,
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ # only load the neural network LM if doing shallow fusion
+ if params.use_shallow_fusion:
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+
+ else:
+ LM = None
+
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -629,6 +773,9 @@ def main():
model=model,
token_table=lexicon.token_table,
decoding_graph=decoding_graph,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ LM=LM,
)
save_results(
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 b/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1
deleted file mode 120000
index bcd4abc2f..000000000
--- a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1
+++ /dev/null
@@ -1 +0,0 @@
-/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
index 566902a85..723414167 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
@@ -48,6 +48,7 @@ import logging
from pathlib import Path
import torch
+from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@@ -132,8 +133,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
@@ -166,9 +166,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -195,9 +195,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg + 1]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -245,6 +245,7 @@ def main():
model.eval()
if params.jit:
+ convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
@@ -252,9 +253,7 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
- filename = (
- params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
- )
+ filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
@@ -262,17 +261,14 @@ def main():
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = (
- params.exp_dir
- / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+ params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
)
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py b/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py
new file mode 120000
index 000000000..557e18aa1
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/lstmp.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
index e150e8230..a4dda0d6d 100644
--- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
@@ -84,9 +84,7 @@ class Transducer(nn.Module):
self.decoder_datatang = decoder_datatang
self.joiner_datatang = joiner_datatang
- self.simple_am_proj = ScaledLinear(
- encoder_dim, vocab_size, initial_speed=0.5
- )
+ self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if decoder_datatang is not None:
@@ -179,9 +177,7 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
- boundary = torch.zeros(
- (x.size(0), 4), dtype=torch.int64, device=x.device
- )
+ boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
index 04a0a882a..ead393e6e 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -196,10 +195,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -257,13 +255,9 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
- encoder_out, encoder_out_lens = model.encoder(
- x=features, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
num_waves = encoder_out.size(0)
hyp_list = []
@@ -311,9 +305,7 @@ def main():
beam=params.beam_size,
)
else:
- raise ValueError(
- f"Unsupported decoding method: {params.method}"
- )
+ raise ValueError(f"Unsupported decoding method: {params.method}")
hyp_list.append(hyp)
hyps = []
@@ -330,9 +322,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
index feaef5cf6..62e67530d 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -96,9 +96,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
-LRSchedulerType = Union[
- torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_model_arguments(parser: argparse.ArgumentParser):
@@ -224,8 +222,7 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
- help="The initial learning rate. This value should not need "
- "to be changed.",
+ help="The initial learning rate. This value should not need to be changed.",
)
parser.add_argument(
@@ -248,8 +245,7 @@ def get_parser():
"--context-size",
type=int,
default=1,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -272,8 +268,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
- help="The scale to smooth the loss with am (output of encoder network)"
- "part.",
+ help="The scale to smooth the loss with am (output of encoder network) part.",
)
parser.add_argument(
@@ -635,11 +630,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
- device = (
- model.device
- if isinstance(model, DDP)
- else next(model.parameters()).device
- )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@@ -670,23 +661,16 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
- 0.0
- if warmup < 1.0
- else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
- )
- loss = (
- params.simple_loss_scale * simple_loss
- + pruned_loss_scale * pruned_loss
+ 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
+ loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- info["frames"] = (
- (feature_lens // params.subsampling_factor).sum().item()
- )
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -824,9 +808,7 @@ def train_one_epoch(
)
# summary stats
if datatang_train_dl is not None:
- tot_loss = (
- tot_loss * (1 - 1 / params.reset_interval)
- ) + loss_info
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
if aishell:
aishell_tot_loss = (
@@ -847,9 +829,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
- display_and_save_batch(
- batch, params=params, graph_compiler=graph_compiler
- )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
if params.print_diagnostics and batch_idx == 5:
@@ -892,9 +872,7 @@ def train_one_epoch(
cur_lr = scheduler.get_last_lr()[0]
if datatang_train_dl is not None:
datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
- tot_loss_str = (
- f"tot_loss[{tot_loss}], batch size: {batch_size}, "
- )
+ tot_loss_str = f"tot_loss[{tot_loss}], batch size: {batch_size}, "
else:
tot_loss_str = ""
datatang_str = ""
@@ -1067,7 +1045,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2 ** 22
+ 2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -1076,9 +1054,7 @@ def run(rank, world_size, args):
train_cuts = filter_short_and_long_utterances(train_cuts)
if args.enable_musan:
- cuts_musan = load_manifest(
- Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
- )
+ cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
else:
cuts_musan = None
@@ -1093,9 +1069,7 @@ def run(rank, world_size, args):
if params.datatang_prob > 0:
datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
train_datatang_cuts = datatang.train_cuts()
- train_datatang_cuts = filter_short_and_long_utterances(
- train_datatang_cuts
- )
+ train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
datatang_train_dl = asr_datamodule.train_dataloaders(
train_datatang_cuts,
@@ -1249,9 +1223,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
- display_and_save_batch(
- batch, params=params, graph_compiler=graph_compiler
- )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py
new file mode 120000
index 000000000..a074d6085
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py
new file mode 120000
index 000000000..8554e44cc
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
new file mode 100755
index 000000000..fcb0ebc4e
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
@@ -0,0 +1,819 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Xiaoyu Yang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7_bbpe/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless7_bbpe/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7_bbpe/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./pruned_transducer_stateless7_bbpe/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless7_bbpe/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless7_bbpe/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless7_bbpe/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall import (
+ LmScorer,
+ NgramLm,
+ byte_encode,
+ smart_byte_decode,
+ tokenize_by_CJK_char,
+)
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bbpe_500/bbpe.model",
+ help="Path to the byte BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bbpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_LG
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ If you use fast_beam_search_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.25,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--ilme-scale",
+ type=float,
+ default=0.2,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for the internal language model estimation.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search";
+ if beam search with a beam size of 7 is used, it would be
+ "beam_7".
+ - value: It contains the decoding result. `len(value)` equals the
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+ fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+ hyps = []
+
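+ # Note on byte BPE: the SentencePiece pieces operate on byte-encoded text,
+ # so smart_byte_decode is used below to reassemble decoded pieces into
+ # UTF-8 characters (dropping byte sequences that do not form a valid
+ # character, as we understand the helper).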
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "fast_beam_search_LG":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ subtract_ilme=True,
+ ilme_scale=params.ilme_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ ref_texts = []
+ for tx in supervisions["text"]:
+ ref_texts.append(byte_encode(tokenize_by_CJK_char(tx)))
+
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(ref_texts),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ else:
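+ # Fallback path: decode one utterance at a time; the batched
+ # implementations above are used whenever the method supports them.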
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(smart_byte_decode(sp.decode(hyp)).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+ key += f"_ilme_scale_{params.ilme_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+ fast_beam_search_nbest, or fast_beam_search_nbest_oracle.
+ Returns:
+ Return a dict whose key may be "greedy_search" if greedy search
+ is used, or "beam_7" if a beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ word_table=word_table,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+
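+ # Scoring for Chinese is done at the character level (i.e. CER), so each
+ # reference/hypothesis is re-split into individual characters here.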
+ results_char = []
+ for res in results:
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results_char, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_LG",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ params.suffix += f"-ilme-scale-{params.ilme_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bbpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
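+ # Each epoch checkpoint stores `model_avg`, the running average of all
+ # parameters seen so far, so the average over (start, epoch] can be
+ # recovered from just the two endpoint checkpoints.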
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
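+ # Pre-scale the LG arc scores (the word n-gram LM costs) by
+ # --ngram-lm-scale before decoding.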
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
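+ # A trivial graph accepts any token sequence, i.e. fast beam search
+ # without an external LM constraint.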
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ aishell = AishellAsrDataModule(args)
+
+ test_cuts = aishell.test_cuts()
+ dev_cuts = aishell.valid_cuts()
+
+ test_dl = aishell.test_dataloaders(test_cuts)
+ dev_dl = aishell.test_dataloaders(dev_cuts)
+
+ test_sets = ["test", "dev"]
+ test_dls = [test_dl, dev_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py
new file mode 120000
index 000000000..b9aa0ae08
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py
new file mode 100755
index 000000000..4e82b45d3
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/export.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./pruned_transducer_stateless7_bbpe/export.py \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./pruned_transducer_stateless7_bbpe/export.py \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --epoch 20 \
+ --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `pruned_transducer_stateless7_bbpe/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/aishell/ASR
+ ./pruned_transducer_stateless7_bbpe/decode.py \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bbpe_500/bbpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
+
+with the following commands:
+
+ sudo apt-get install git-lfs
+ git lfs install
+ git clone https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
+ # You will find the pre-trained model in icefall_asr_aishell_pruned_transducer_stateless7_bbpe/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7_bbpe/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bbpe_500/bbpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ It will generate a file named cpu_jit.pt
+
+ Check ./jit_pretrained.py for how to use it.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bbpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ model.to(device)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit is True:
+ convert_scaled_to_non_scaled(model, inplace=True)
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+ # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torchscript. Export model.state_dict()")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
new file mode 100755
index 000000000..0c43bf74b
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
@@ -0,0 +1,274 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./pruned_transducer_stateless7_bbpe/export.py \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --epoch 49 \
+ --avg 28 \
+ --jit 1
+
+Usage of this script:
+
+./pruned_transducer_stateless7_bbpe/jit_pretrained.py \
+ --nn-model-filename ./pruned_transducer_stateless7_bbpe/exp/cpu_jit.pt \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall import smart_byte_decode
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--nn-model-filename",
+ type=str,
+ required=True,
+ help="Path to the torchscript model cpu_jit.pt",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ help="""Path to bpe.model.""",
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+def greedy_search(
+ model: torch.jit.ScriptModule,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+) -> List[List[int]]:
+ """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+ Args:
+ model:
+ The transducer model.
+ encoder_out:
+ A 3-D tensor of shape (N, T, C)
+ encoder_out_lens:
+ A 1-D tensor of shape (N,).
+ Returns:
+ Return the decoded results for each utterance.
+ """
+ assert encoder_out.ndim == 3
+ assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+ packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+ input=encoder_out,
+ lengths=encoder_out_lens.cpu(),
+ batch_first=True,
+ enforce_sorted=False,
+ )
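+
+ # pack_padded_sequence sorts utterances by length (enforce_sorted=False)
+ # and exposes `batch_sizes`, the number of utterances still active at each
+ # frame, so the loop below shrinks the effective batch as shorter
+ # utterances finish.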
+
+ device = encoder_out.device
+ blank_id = 0 # hard-code to 0
+
+ batch_size_list = packed_encoder_out.batch_sizes.tolist()
+ N = encoder_out.size(0)
+
+ assert torch.all(encoder_out_lens > 0), encoder_out_lens
+ assert N == batch_size_list[0], (N, batch_size_list)
+
+ context_size = model.decoder.context_size
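+ # The stateless decoder conditions on only the last `context_size` tokens,
+ # so every hypothesis is seeded with that many blanks to give the first
+ # decoder call a valid (all-blank) context.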
+ hyps = [[blank_id] * context_size for _ in range(N)]
+
+ decoder_input = torch.tensor(
+ hyps,
+ device=device,
+ dtype=torch.int64,
+ ) # (N, context_size)
+
+ decoder_out = model.decoder(
+ decoder_input,
+ need_pad=torch.tensor([False]),
+ ).squeeze(1)
+
+ offset = 0
+ for batch_size in batch_size_list:
+ start = offset
+ end = offset + batch_size
+ current_encoder_out = packed_encoder_out.data[start:end]
+ # current_encoder_out's shape: (batch_size, encoder_out_dim)
+ offset = end
+
+ decoder_out = decoder_out[:batch_size]
+
+ logits = model.joiner(
+ current_encoder_out,
+ decoder_out,
+ )
+ # logits' shape: (batch_size, vocab_size)
+
+ assert logits.ndim == 2, logits.shape
+ y = logits.argmax(dim=1).tolist()
+ emitted = False
+ for i, v in enumerate(y):
+ if v != blank_id:
+ hyps[i].append(v)
+ emitted = True
+ if emitted:
+ # update decoder output
+ decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
+ decoder_input = torch.tensor(
+ decoder_input,
+ device=device,
+ dtype=torch.int64,
+ )
+ decoder_out = model.decoder(
+ decoder_input,
+ need_pad=torch.tensor([False]),
+ )
+ decoder_out = decoder_out.squeeze(1)
+
+ sorted_ans = [h[context_size:] for h in hyps]
+ ans = []
+ unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+ for i in range(N):
+ ans.append(sorted_ans[unsorted_indices[i]])
+
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ model = torch.jit.load(args.nn_model_filename)
+
+ model.eval()
+
+ model.to(device)
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(args.bpe_model)
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {args.sound_files}")
+ waves = read_sound_files(
+ filenames=args.sound_files,
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
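+ # Pad features to a common length; padding_value=log(1e-10) is effectively
+ # silence in log-mel space (the same LOG_EPS value used in decode.py).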
+ features = pad_sequence(
+ features,
+ batch_first=True,
+ padding_value=math.log(1e-10),
+ )
+
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=features,
+ x_lens=feature_lengths,
+ )
+
+ hyps = greedy_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ s = "\n"
+ for filename, hyp in zip(args.sound_files, hyps):
+ words = smart_byte_decode(sp.decode(hyp))
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
new file mode 100755
index 000000000..ea5bda4db
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
@@ -0,0 +1,345 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+./pruned_transducer_stateless7_bbpe/export.py \
+ --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --epoch 48 \
+ --avg 29
+
+Usage of this script:
+
+(1) greedy search
+./pruned_transducer_stateless7_bbpe/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./pruned_transducer_stateless7_bbpe/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./pruned_transducer_stateless7_bbpe/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method fast_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+Note: ./pruned_transducer_stateless7_bbpe/exp/pretrained.pt is generated by
+./pruned_transducer_stateless7_bbpe/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall import smart_byte_decode
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ help="""Path to bpe.model.""",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame. Used only when
+ --method is greedy_search.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+
+ params.update(vars(args))
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bbpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info("Creating model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
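+ # strict=False tolerates keys present only in the training-time model (an
+ # assumption on our part); the weights themselves must still match.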
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+ num_waves = encoder_out.size(0)
+ hyps = []
+ msg = f"Using {params.method}"
+ if params.method == "beam_search":
+ msg += f" with beam size {params.beam_size}"
+ logging.info(msg)
+
+ if params.method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ else:
+ for i in range(num_waves):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(f"Unsupported method: {params.method}")
+
+ hyps.append(smart_byte_decode(sp.decode(hyp)).split())
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py
new file mode 120000
index 000000000..7ceac5d10
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
new file mode 100755
index 000000000..499badb14
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
@@ -0,0 +1,1261 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_bbpe/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless7_bbpe/exp \
+ --max-duration 400
+
+# For mixed precision training:
+
+./pruned_transducer_stateless7_bbpe/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7_bbpe/exp \
+ --max-duration 800
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut, CutSet
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import byte_encode, diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ filter_uneven_sized_batch,
+ setup_logger,
+ str2bool,
+ tokenize_by_CJK_char,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
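+ # Some Zipformer submodules expose a `batch_count` attribute that gates
+ # warm-up behaviour; propagate the global batch count to all of them
+ # (our summary of why this helper exists).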
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,4,3,2,4",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="1024,1024,2048,2048,1024",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="384,384,384,384,384",
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+ not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+ " worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7_bbpe/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bbpe_500/bbpe.model",
+ help="Path to the Byte BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.05, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=2000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
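The `--average-period` help above spells out the running-average rule used for `model_avg` (applied in the training loop via `update_averaged_model`). Below is a minimal sketch of that update on a single tensor, with a hypothetical helper name and purely illustrative values:

```python
import torch


def update_running_average(
    avg_param: torch.Tensor,
    cur_param: torch.Tensor,
    average_period: int,
    batch_idx_train: int,
) -> torch.Tensor:
    # model_avg = model * (average_period / batch_idx_train)
    #             + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
    weight_cur = average_period / batch_idx_train
    weight_avg = (batch_idx_train - average_period) / batch_idx_train
    return cur_param * weight_cur + avg_param * weight_avg


# After 1000 batches with average_period=200, the current model contributes
# 20% and the accumulated average 80%.
avg = torch.zeros(3)
cur = torch.ones(3)
print(update_running_average(avg, cur, average_period=200, batch_idx_train=1000))
# tensor([0.2000, 0.2000, 0.2000])
```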
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "frame_shift_ms": 10.0,
+ "allowed_excess_duration_ratio": 0.1,
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 2000,
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
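The docstring above gives `--start-batch` precedence over `--start-epoch`. A standalone sketch of that resolution order (hypothetical helper name, assuming `exp_dir` is a `pathlib.Path` as in the script):

```python
from pathlib import Path
from typing import Optional


def resolve_resume_checkpoint(
    exp_dir: Path, start_batch: int, start_epoch: int
) -> Optional[Path]:
    # --start-batch wins over --start-epoch, mirroring load_checkpoint_if_available.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh run, nothing to load


print(resolve_resume_checkpoint(Path("exp"), start_batch=0, start_epoch=5))
# exp/epoch-4.pt
```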
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ # For the uneven-sized batch, the total duration after padding would possibly
+ # cause OOM. Hence, for each batch, which is sorted in descending order by length,
+ # we simply drop the last few shortest samples, so that the retained total frames
+ # (after padding) would not exceed `allowed_max_frames`:
+ # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+ # where `max_frames = max_duration * 1000 // frame_shift_ms`.
+ # We set allowed_excess_duration_ratio=0.1.
+ max_frames = params.max_duration * 1000 // params.frame_shift_ms
+ allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
+ batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
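To make the two formulas in `compute_loss` concrete, the sketch below reproduces the padded-frame budget and the warm-up schedule of the simple/pruned loss scales; `max_duration=300` seconds is only an illustrative assumption, not a value set by this script:

```python
def frame_budget(
    max_duration_s: float,
    frame_shift_ms: float = 10.0,
    allowed_excess_duration_ratio: float = 0.1,
) -> int:
    # allowed_max_frames = int(max_frames * (1 + ratio)); max_frames = dur * 1000 // shift
    max_frames = max_duration_s * 1000 // frame_shift_ms
    return int(max_frames * (1.0 + allowed_excess_duration_ratio))


def loss_scales(batch_idx_train: int, warm_step: int = 2000, s: float = 0.5):
    # Simple loss decays from 1.0 to s; pruned loss ramps from 0.1 to 1.0.
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    return 1.0 - frac * (1.0 - s), 0.1 + 0.9 * frac


print(frame_budget(300))  # 33000 frames allowed after padding
for step in (0, 1000, 2000):
    print(step, loss_scales(step))
# 0    (1.0, 0.1)   -> rely mostly on the simple loss early on
# 1000 (0.75, 0.55) -> halfway through warm-up
# 2000 (0.5, 1.0)   -> fully warmed up
```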
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
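The fp16 branch in `train_one_epoch` periodically doubles a grad scale that has drifted too low and aborts if it collapses. A minimal sketch of that guard as a pure function, detached from `GradScaler` (hypothetical name):

```python
import logging


def adjust_grad_scale(cur_grad_scale: float, batch_idx: int) -> float:
    # Mirrors the check run every 100 batches when --use-fp16 is enabled.
    if cur_grad_scale < 0.01:
        logging.warning(f"Grad scale is small: {cur_grad_scale}")
    if cur_grad_scale < 1.0e-05:
        raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
    if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
        return cur_grad_scale * 2.0  # nudge the scale back up
    return cur_grad_scale


print(adjust_grad_scale(0.5, batch_idx=100))  # 1.0
print(adjust_grad_scale(4.0, batch_idx=400))  # 8.0
print(adjust_grad_scale(4.0, batch_idx=100))  # 4.0 (unchanged)
```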
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bbpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ aishell = AishellAsrDataModule(args)
+
+ train_cuts = aishell.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 12 seconds
+ #
+ # Caution: There is a reason to select 12.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 12.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ def tokenize_and_encode_text(c: Cut):
+ # Text normalize for each sample
+ text = c.supervisions[0].text
+ text = byte_encode(tokenize_by_CJK_char(text))
+ c.supervisions[0].text = text
+ return c
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ train_cuts = train_cuts.map(tokenize_and_encode_text)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = aishell.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = aishell.valid_cuts()
+ valid_cuts = valid_cuts.map(tokenize_and_encode_text)
+
+ valid_dl = aishell.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
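As a worked example of the length check in `remove_short_and_long_utt`: with a 10 ms frame shift, a 5-second cut has about 500 frames, and the subsampling expression gives T = ((500 - 7) // 2 + 1) // 2 = 123, so a transcript longer than 123 tokens would be dropped. A self-contained sketch with illustrative numbers:

```python
def frames_after_subsampling(num_frames: int) -> int:
    # Same expression as in remove_short_and_long_utt / zipformer.py.
    return ((num_frames - 7) // 2 + 1) // 2


def keeps_cut(num_frames: int, num_tokens: int) -> bool:
    # Pruned RNN-T requires T >= S: frames after subsampling >= tokens.
    return frames_after_subsampling(num_frames) >= num_tokens


print(frames_after_subsampling(500))   # 123
print(keeps_cut(500, num_tokens=100))  # True
print(keeps_cut(500, num_tokens=150))  # False
```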
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index d24ba6bb7..efb32336a 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -21,7 +21,7 @@ import inspect
import logging
from functools import lru_cache
from pathlib import Path
-from typing import List
+from typing import Any, Dict, List, Optional
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
@@ -181,19 +181,24 @@ class AishellAsrDataModule:
"with training dataset. ",
)
- def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+ def train_dataloaders(
+ self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
logging.info("About to get Musan cuts")
- cuts_musan = load_manifest(
- self.args.manifest_dir / "musan_cuts.jsonl.gz"
- )
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
- CutMix(
- cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
- )
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@@ -215,9 +220,7 @@ class AishellAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
- logging.info(
- f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
- )
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@@ -260,9 +263,7 @@ class AishellAsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@@ -285,6 +286,10 @@ class AishellAsrDataModule:
)
logging.info("About to create train dataloader")
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
train_dl = DataLoader(
train,
sampler=train_sampler,
@@ -308,9 +313,7 @@ class AishellAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
@@ -335,7 +338,7 @@ class AishellAsrDataModule:
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
- logging.debug("About to create test dataset")
+ logging.info("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
@@ -366,13 +369,9 @@ class AishellAsrDataModule:
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
- )
+ return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
- )
+ return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
index 66b734fc4..824ca2a92 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
@@ -265,9 +265,7 @@ def decode_dataset(
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -289,9 +287,7 @@ def save_results(
# We compute CER for aishell dataset.
results_char = []
for res in results:
- results_char.append(
- (res[0], list("".join(res[1])), list("".join(res[2])))
- )
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
test_set_wers[key] = wer
@@ -335,9 +331,7 @@ def main():
logging.info(f"device: {device}")
- HLG = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
- )
+ HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
@@ -362,9 +356,7 @@ def main():
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
- torch.save(
- {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
- )
+ torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
model.to(device)
model.eval()
@@ -392,9 +384,7 @@ def main():
lexicon=lexicon,
)
- save_results(
- params=params, test_set_name=test_set, results_dict=results_dict
- )
+ save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py
index 5e04c11b4..1731e1ebe 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/model.py
@@ -66,10 +66,7 @@ class TdnnLstm(nn.Module):
nn.BatchNorm1d(num_features=500, affine=False),
)
self.lstms = nn.ModuleList(
- [
- nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
- for _ in range(5)
- ]
+ [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)]
)
self.lstm_bnorms = nn.ModuleList(
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
index 9bd810809..7e7213501 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
@@ -53,9 +53,7 @@ def get_parser():
help="Path to words.txt",
)
- parser.add_argument(
- "--HLG", type=str, required=True, help="Path to HLG.pt."
- )
+ parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.")
parser.add_argument(
"--method",
@@ -112,10 +110,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -173,9 +170,7 @@ def main():
logging.info("Decoding started")
features = fbank(waves)
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = features.permute(0, 2, 1) # now features is [N, C, T]
with torch.no_grad():
@@ -219,9 +214,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
index 7619b0551..e574cf89b 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
@@ -49,12 +49,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
-from icefall.utils import (
- AttributeDict,
- encode_supervisions,
- setup_logger,
- str2bool,
-)
+from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool
def get_parser():
diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py
index 9ed9b2ad1..de0a8d0f5 100644
--- a/egs/aishell/ASR/transducer_stateless/beam_search.py
+++ b/egs/aishell/ASR/transducer_stateless/beam_search.py
@@ -47,9 +47,9 @@ def greedy_search(
device = model.device
- decoder_input = torch.tensor(
- [blank_id] * context_size, device=device
- ).reshape(1, context_size)
+ decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+ 1, context_size
+ )
decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -81,9 +81,9 @@ def greedy_search(
y = logits.argmax().item()
if y != blank_id:
hyp.append(y)
- decoder_input = torch.tensor(
- [hyp[-context_size:]], device=device
- ).reshape(1, context_size)
+ decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+ 1, context_size
+ )
decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -157,9 +157,7 @@ class HypothesisList(object):
"""
if length_norm:
- return max(
- self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
- )
+ return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
@@ -246,9 +244,9 @@ def beam_search(
device = model.device
- decoder_input = torch.tensor(
- [blank_id] * context_size, device=device
- ).reshape(1, context_size)
+ decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
+ 1, context_size
+ )
decoder_out = model.decoder(decoder_input, need_pad=False)
diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py
index 64114253d..78424aea2 100644
--- a/egs/aishell/ASR/transducer_stateless/conformer.py
+++ b/egs/aishell/ASR/transducer_stateless/conformer.py
@@ -155,9 +155,7 @@ class ConformerEncoderLayer(nn.Module):
normalize_before: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
- self.self_attn = RelPositionMultiheadAttention(
- d_model, nhead, dropout=0.0
- )
+ self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
@@ -175,18 +173,14 @@ class ConformerEncoderLayer(nn.Module):
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
- self.norm_ff_macaron = nn.LayerNorm(
- d_model
- ) # for the macaron style FNN module
+ self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.ff_scale = 0.5
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
- self.norm_final = nn.LayerNorm(
- d_model
- ) # for the final output of the block
+ self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.dropout = nn.Dropout(dropout)
@@ -220,9 +214,7 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
- src = residual + self.ff_scale * self.dropout(
- self.feed_forward_macaron(src)
- )
+ src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src))
if not self.normalize_before:
src = self.norm_ff_macaron(src)
@@ -341,9 +333,7 @@ class RelPositionalEncoding(torch.nn.Module):
"""
- def __init__(
- self, d_model: int, dropout_rate: float, max_len: int = 5000
- ) -> None:
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
@@ -359,9 +349,7 @@ class RelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
- if self.pe.dtype != x.dtype or str(self.pe.device) != str(
- x.device
- ):
+ if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vector and `j` means the
@@ -631,9 +619,9 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
- q, k, v = nn.functional.linear(
- query, in_proj_weight, in_proj_bias
- ).chunk(3, dim=-1)
+ q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
+ 3, dim=-1
+ )
elif torch.equal(key, value):
# encoder-decoder attention
@@ -701,31 +689,22 @@ class RelPositionMultiheadAttention(nn.Module):
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
- raise RuntimeError(
- "The size of the 2D attn_mask is not correct."
- )
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
- raise RuntimeError(
- "The size of the 3D attn_mask is not correct."
- )
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
- "attn_mask's dimension {} is not supported".format(
- attn_mask.dim()
- )
+ "attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
- if (
- key_padding_mask is not None
- and key_padding_mask.dtype == torch.uint8
- ):
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
@@ -764,9 +743,7 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
- matrix_ac = torch.matmul(
- q_with_bias_u, k
- ) # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
@@ -778,9 +755,7 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
- attn_output_weights = attn_output_weights.view(
- bsz * num_heads, tgt_len, -1
- )
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
@@ -814,13 +789,9 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = (
- attn_output.transpose(0, 1)
- .contiguous()
- .view(tgt_len, bsz, embed_dim)
- )
- attn_output = nn.functional.linear(
- attn_output, out_proj_weight, out_proj_bias
+ attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
)
+ attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
@@ -843,9 +814,7 @@ class ConvolutionModule(nn.Module):
"""
- def __init__(
- self, channels: int, kernel_size: int, bias: bool = True
- ) -> None:
+ def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index 780b0c4bb..d23f4f883 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -99,8 +99,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -227,9 +226,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
- encoder_out, encoder_out_lens = model.encoder(
- x=feature, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
batch_size = encoder_out.size(0)
@@ -248,9 +245,7 @@ def decode_one_batch(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
- raise ValueError(
- f"Unsupported decoding method: {params.decoding_method}"
- )
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "greedy_search":
@@ -319,9 +314,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -332,23 +325,17 @@ def save_results(
):
test_set_wers = dict()
for key, results in results_dict.items():
- recog_path = (
- params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
- errs_filename = (
- params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
# we compute CER for aishell dataset.
results_char = []
for res in results:
- results_char.append(
- (res[0], list("".join(res[1])), list("".join(res[2])))
- )
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -358,10 +345,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
- errs_info = (
- params.res_dir
- / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
@@ -430,9 +414,7 @@ def main():
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
- torch.save(
- {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
- )
+ torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt")
return
model.to(device)
diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py
index c2c6552a9..70e9e6c96 100644
--- a/egs/aishell/ASR/transducer_stateless/decoder.py
+++ b/egs/aishell/ASR/transducer_stateless/decoder.py
@@ -86,9 +86,7 @@ class Decoder(nn.Module):
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
- embedding_out = F.pad(
- embedding_out, pad=(self.context_size - 1, 0)
- )
+ embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index 4c6519b96..01de5d772 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -110,8 +110,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
return parser
@@ -243,9 +242,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py
index 994305fc1..591bbe44f 100644
--- a/egs/aishell/ASR/transducer_stateless/model.py
+++ b/egs/aishell/ASR/transducer_stateless/model.py
@@ -103,9 +103,7 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
- boundary = torch.zeros(
- (x.size(0), 4), dtype=torch.int64, device=x.device
- )
+ boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index db89c4d67..40f430e13 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -117,8 +117,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -211,10 +210,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -273,9 +271,7 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
@@ -319,9 +315,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index d54157709..62ffff473 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -126,8 +126,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -389,9 +388,7 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- info["frames"] = (
- (feature_lens // params.subsampling_factor).sum().item()
- )
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -504,9 +501,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@@ -625,9 +620,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
- tb_writer.add_scalar(
- "train/learning_rate", cur_lr, params.batch_idx_train
- )
+ tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py
index e851dcc32..b3ff153c1 100644
--- a/egs/aishell/ASR/transducer_stateless/transformer.py
+++ b/egs/aishell/ASR/transducer_stateless/transformer.py
@@ -250,9 +250,7 @@ def _get_activation_fn(activation: str):
elif activation == "gelu":
return nn.functional.gelu
- raise RuntimeError(
- "activation should be relu/gelu, not {}".format(activation)
- )
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class PositionalEncoding(nn.Module):
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
index 838e53658..5d49d7338 100644
--- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
@@ -29,10 +29,7 @@ from lhotse.dataset import (
K2SpeechRecognitionDataset,
SpecAugment,
)
-from lhotse.dataset.input_strategies import (
- OnTheFlyFeatures,
- PrecomputedFeatures,
-)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
from torch.utils.data import DataLoader
from icefall.utils import str2bool
@@ -162,9 +159,7 @@ class AsrDataModule:
if cuts_musan is not None:
logging.info("Enable MUSAN")
transforms.append(
- CutMix(
- cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
- )
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@@ -173,9 +168,7 @@ class AsrDataModule:
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
- logging.info(
- f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
- )
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@@ -252,9 +245,7 @@ class AsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
index ea3f94fd8..d164b6890 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
@@ -170,8 +170,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -227,9 +226,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
- encoder_out, encoder_out_lens = model.encoder(
- x=feature, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@@ -241,10 +238,7 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -365,9 +359,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -378,24 +370,18 @@ def save_results(
):
test_set_wers = dict()
for key, results in results_dict.items():
- recog_path = (
- params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
- errs_filename = (
- params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
# we compute CER for aishell dataset.
results_char = []
for res in results:
- results_char.append(
- (res[0], list("".join(res[1])), list("".join(res[2])))
- )
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -405,10 +391,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
- errs_info = (
- params.res_dir
- / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
@@ -448,9 +431,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
- params.suffix += (
- f"-{params.decoding_method}-beam-size-{params.beam_size}"
- )
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
index 3bd2ceb11..c1081c32b 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
@@ -109,8 +109,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
return parser
@@ -241,9 +240,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
index a95a4bc52..5d8ca2e11 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -194,10 +193,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -254,13 +252,9 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
- encoder_out, encoder_out_lens = model.encoder(
- x=features, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
num_waves = encoder_out.size(0)
hyp_list = []
@@ -308,9 +302,7 @@ def main():
beam=params.beam_size,
)
else:
- raise ValueError(
- f"Unsupported decoding method: {params.method}"
- )
+ raise ValueError(f"Unsupported decoding method: {params.method}")
hyp_list.append(hyp)
hyps = []
@@ -327,9 +319,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
index 225d0d709..8fb7d1e49 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
@@ -149,8 +149,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -168,8 +167,7 @@ def get_parser():
"--datatang-prob",
type=float,
default=0.2,
- help="The probability to select a batch from the "
- "aidatatang_200zh dataset",
+ help="The probability to select a batch from the aidatatang_200zh dataset",
)
return parser
@@ -449,9 +447,7 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- info["frames"] = (
- (feature_lens // params.subsampling_factor).sum().item()
- )
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -605,9 +601,7 @@ def train_one_epoch(
f"train/current_{prefix}_",
params.batch_idx_train,
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
aishell_tot_loss.write_summary(
tb_writer, "train/aishell_tot_", params.batch_idx_train
)
@@ -735,9 +729,7 @@ def run(rank, world_size, args):
train_datatang_cuts = train_datatang_cuts.repeat(times=None)
if args.enable_musan:
- cuts_musan = load_manifest(
- Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
- )
+ cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz")
else:
cuts_musan = None
@@ -776,9 +768,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
- tb_writer.add_scalar(
- "train/learning_rate", cur_lr, params.batch_idx_train
- )
+ tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py
index 65fcda873..0a7d87fe8 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py
@@ -171,8 +171,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -231,9 +230,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
- encoder_out, encoder_out_lens = model.encoder(
- x=feature, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
@@ -245,10 +242,7 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -369,9 +363,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -382,24 +374,18 @@ def save_results(
):
test_set_wers = dict()
for key, results in results_dict.items():
- recog_path = (
- params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
- errs_filename = (
- params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
# we compute CER for aishell dataset.
results_char = []
for res in results:
- results_char.append(
- (res[0], list("".join(res[1])), list("".join(res[2])))
- )
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -409,10 +395,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
- errs_info = (
- params.res_dir
- / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
@@ -452,9 +435,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
- params.suffix += (
- f"-{params.decoding_method}-beam-size-{params.beam_size}"
- )
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
index 11335a834..3e14ad69c 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/export.py
@@ -109,8 +109,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
return parser
@@ -241,9 +240,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
index 262e822c2..9e4459247 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
@@ -165,8 +165,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -194,10 +193,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -254,13 +252,9 @@ def main():
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
- encoder_out, encoder_out_lens = model.encoder(
- x=features, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
num_waves = encoder_out.size(0)
hyp_list = []
@@ -308,9 +302,7 @@ def main():
beam=params.beam_size,
)
else:
- raise ValueError(
- f"Unsupported decoding method: {params.method}"
- )
+ raise ValueError(f"Unsupported decoding method: {params.method}")
hyp_list.append(hyp)
hyps = []
@@ -327,9 +319,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py
index d3ffccafa..5f116f2bd 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/train.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/train.py
@@ -142,8 +142,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -414,9 +413,7 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- info["frames"] = (
- (feature_lens // params.subsampling_factor).sum().item()
- )
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -529,9 +526,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@@ -657,9 +652,7 @@ def run(rank, world_size, args):
cur_lr = optimizer._rate
if tb_writer is not None:
- tb_writer.add_scalar(
- "train/learning_rate", cur_lr, params.batch_idx_train
- )
+ tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py
old mode 100755
new mode 100644
diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
index d8d3622bd..ec0c584ca 100755
--- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py
+++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
@@ -83,9 +83,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
- cut_set
- + cut_set.perturb_speed(0.9)
- + cut_set.perturb_speed(1.1)
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@@ -111,9 +109,7 @@ def get_args():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh
index 06810bfdd..3e8e840ab 100755
--- a/egs/aishell2/ASR/prepare.sh
+++ b/egs/aishell2/ASR/prepare.sh
@@ -1,5 +1,8 @@
#!/usr/bin/env bash
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
set -eou pipefail
nj=30
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
old mode 100755
new mode 100644
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
old mode 100755
new mode 100644
index b7a21f579..0f383a244
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -216,13 +216,9 @@ class AiShell2AsrDataModule:
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
- cuts_musan = load_manifest(
- self.args.manifest_dir / "musan_cuts.jsonl.gz"
- )
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
- CutMix(
- cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
- )
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@@ -244,9 +240,7 @@ class AiShell2AsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
- logging.info(
- f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
- )
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@@ -290,9 +284,7 @@ class AiShell2AsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@@ -348,9 +340,7 @@ class AiShell2AsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
@@ -406,9 +396,7 @@ class AiShell2AsrDataModule:
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
- return load_manifest_lazy(
- self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
- )
+ return load_manifest_lazy(self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz")
@lru_cache()
def test_cuts(self) -> CutSet:
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
index 915737f4a..9e44b4e34 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
@@ -269,8 +269,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -348,9 +347,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
- encoder_out, encoder_out_lens = model.encoder(
- x=feature, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
@@ -409,10 +406,7 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -538,9 +532,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -551,18 +543,14 @@ def save_results(
):
test_set_wers = dict()
for key, results in results_dict.items():
- recog_path = (
- params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
- errs_filename = (
- params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
@@ -572,10 +560,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
- errs_info = (
- params.res_dir
- / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
@@ -625,9 +610,7 @@ def main():
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
- params.suffix += (
- f"-{params.decoding_method}-beam-size-{params.beam_size}"
- )
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -661,9 +644,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -690,9 +673,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg + 1]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -749,9 +732,7 @@ def main():
)
decoding_graph.scores *= params.ngram_lm_scale
else:
- decoding_graph = k2.trivial_graph(
- params.vocab_size - 1, device=device
- )
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
index bc7bd71cb..8a5be94d0 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
@@ -133,8 +133,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
@@ -167,9 +166,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -196,9 +195,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg + 1]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -266,9 +265,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
index 09de1bece..bc3ae7abf 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -159,8 +159,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -191,10 +190,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -254,15 +252,11 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
- encoder_out, encoder_out_lens = model.encoder(
- x=features, x_lens=feature_lengths
- )
+ encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
@@ -334,9 +328,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
index 838a0497f..74bf68ccb 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
@@ -92,9 +92,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
-LRSchedulerType = Union[
- torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_model_arguments(parser: argparse.ArgumentParser):
@@ -220,8 +218,7 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
- help="The initial learning rate. This value should not need "
- "to be changed.",
+ help="The initial learning rate. This value should not need to be changed.",
)
parser.add_argument(
@@ -244,8 +241,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -268,8 +264,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
- help="The scale to smooth the loss with am (output of encoder network)"
- "part.",
+ help="The scale to smooth the loss with am (output of encoder network) part.",
)
parser.add_argument(
@@ -603,11 +598,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
- device = (
- model.device
- if isinstance(model, DDP)
- else next(model.parameters()).device
- )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@@ -636,23 +627,16 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
- 0.0
- if warmup < 1.0
- else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
- )
- loss = (
- params.simple_loss_scale * simple_loss
- + pruned_loss_scale * pruned_loss
+ 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
+ loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- info["frames"] = (
- (feature_lens // params.subsampling_factor).sum().item()
- )
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -771,9 +755,7 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
- display_and_save_batch(
- batch, params=params, graph_compiler=graph_compiler
- )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
if params.print_diagnostics and batch_idx == 5:
@@ -829,9 +811,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@@ -939,7 +919,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2 ** 22
+ 2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -1104,9 +1084,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
- display_and_save_batch(
- batch, params=params, graph_compiler=graph_compiler
- )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
index 3f50d9e3e..400c406f0 100755
--- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py
+++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
@@ -85,9 +85,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
- cut_set
- + cut_set.perturb_speed(0.9)
- + cut_set.perturb_speed(1.1)
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
@@ -120,9 +118,7 @@ def get_args():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/aishell4/ASR/local/prepare_char.py
+++ b/egs/aishell4/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
- pieces = [
- token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
- ]
+ pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
-def generate_lexicon(
- token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/aishell4/ASR/local/prepare_lang.py
+++ b/egs/aishell4/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
- parser.add_argument(
- "--lang-dir", type=str, help="The lang dir, data/lang_phone"
- )
+ parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
return parser.parse_args()
diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/aishell4/ASR/local/test_prepare_lang.py
+++ b/egs/aishell4/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
- fsa_disambig = lexicon_to_fst(
- lexicon_disambig, phone2id=phone2id, word2id=word2id
- )
+ fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/aishell4/ASR/local/text2token.py
+++ b/egs/aishell4/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
- parser.add_argument(
- "--space", default="", type=str, help="space symbol"
- )
+ parser.add_argument("--space", default="", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
@@ -66,9 +64,7 @@ def get_parser():
type=str,
help="list of non-linguistic symobles, e.g., etc.",
)
- parser.add_argument(
- "text", type=str, default=False, nargs="?", help="input text"
- )
+ parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
@@ -108,8 +104,7 @@ def token2id(
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
- token_table[txt] if txt in token_table else oov_id
- for txt in text
+ token_table[txt] if txt in token_table else oov_id for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
- f = codecs.getreader("utf-8")(
- sys.stdin if is_python2 else sys.stdin.buffer
- )
+ f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh
index c351e3964..cb2b73a3e 100755
--- a/egs/aishell4/ASR/prepare.sh
+++ b/egs/aishell4/ASR/prepare.sh
@@ -1,5 +1,8 @@
#!/usr/bin/env bash
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
set -eou pipefail
stage=-1
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 7aa53ddda..d980a857f 100644
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -222,17 +222,13 @@ class Aishell4AsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
- cuts_musan = load_manifest(
- self.args.manifest_dir / "musan_cuts.jsonl.gz"
- )
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
- CutMix(
- cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
- )
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@@ -254,9 +250,7 @@ class Aishell4AsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
- logging.info(
- f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
- )
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@@ -300,9 +294,7 @@ class Aishell4AsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@@ -359,9 +351,7 @@ class Aishell4AsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
index 14e44c7d9..068e2749a 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
@@ -201,8 +201,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -260,9 +259,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
- encoder_out, encoder_out_lens = model.encoder(
- x=feature, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
@@ -277,10 +274,7 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -401,9 +395,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -414,18 +406,14 @@ def save_results(
):
test_set_wers = dict()
for key, results in results_dict.items():
- recog_path = (
- params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
- errs_filename = (
- params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
@@ -435,10 +423,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
- errs_info = (
- params.res_dir
- / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
@@ -480,9 +465,7 @@ def main():
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
- params.suffix += (
- f"-{params.decoding_method}-beam-size-{params.beam_size}"
- )
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -510,9 +493,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -543,9 +526,9 @@ def main():
)
else:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg + 1]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
index 993341131..bf9856c60 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
@@ -136,8 +136,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
@@ -169,9 +168,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -202,9 +201,9 @@ def main():
)
else:
if params.iter > 0:
- filenames = find_checkpoints(
- params.exp_dir, iteration=-params.iter
- )[: params.avg + 1]
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -276,9 +275,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
index 1fa893637..ee898c303 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
@@ -172,8 +172,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -204,10 +203,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -266,15 +264,11 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
- encoder_out, encoder_out_lens = model.encoder(
- x=features, x_lens=feature_lengths
- )
+ encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
@@ -306,10 +300,7 @@ def main():
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -350,9 +341,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
index 0a48b9059..d7c69f226 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -85,9 +85,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
-LRSchedulerType = Union[
- torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def add_model_arguments(parser: argparse.ArgumentParser):
@@ -213,8 +211,7 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
- help="The initial learning rate. This value should not need "
- "to be changed.",
+ help="The initial learning rate. This value should not need to be changed.",
)
parser.add_argument(
@@ -237,8 +234,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -261,8 +257,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
- help="The scale to smooth the loss with am (output of encoder network)"
- "part.",
+ help="The scale to smooth the loss with am (output of encoder network) part.",
)
parser.add_argument(
@@ -599,11 +594,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
- device = (
- model.device
- if isinstance(model, DDP)
- else next(model.parameters()).device
- )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@@ -633,22 +624,15 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
- 0.0
- if warmup < 1.0
- else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
- )
- loss = (
- params.simple_loss_scale * simple_loss
- + pruned_loss_scale * pruned_loss
+ 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
+ loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- info["frames"] = (
- (feature_lens // params.subsampling_factor).sum().item()
- )
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -827,9 +811,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@@ -937,7 +919,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2 ** 22
+ 2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
index af926aa53..96115a230 100755
--- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
+++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
@@ -84,9 +84,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
)
if "train" in partition:
cut_set = (
- cut_set
- + cut_set.perturb_speed(0.9)
- + cut_set.perturb_speed(1.1)
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
cur_num_jobs = num_jobs if ex is None else 80
cur_num_jobs = min(cur_num_jobs, len(cut_set))
@@ -121,9 +119,7 @@ def get_args():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py
index d9e47d17a..6b440dfb3 100755
--- a/egs/alimeeting/ASR/local/prepare_char.py
+++ b/egs/alimeeting/ASR/local/prepare_char.py
@@ -86,9 +86,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
- pieces = [
- token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
- ]
+ pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@@ -142,9 +140,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
return False
-def generate_lexicon(
- token_sym_table: Dict[str, int], words: List[str]
-) -> Lexicon:
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py
index e5ae89ec4..c8cf9b881 100755
--- a/egs/alimeeting/ASR/local/prepare_lang.py
+++ b/egs/alimeeting/ASR/local/prepare_lang.py
@@ -317,9 +317,7 @@ def lexicon_to_fst(
def get_args():
parser = argparse.ArgumentParser()
- parser.add_argument(
- "--lang-dir", type=str, help="The lang dir, data/lang_phone"
- )
+ parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
return parser.parse_args()
diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py
index d4cf62bba..74e025ad7 100755
--- a/egs/alimeeting/ASR/local/test_prepare_lang.py
+++ b/egs/alimeeting/ASR/local/test_prepare_lang.py
@@ -88,9 +88,7 @@ def test_read_lexicon(filename: str):
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
- fsa_disambig = lexicon_to_fst(
- lexicon_disambig, phone2id=phone2id, word2id=word2id
- )
+ fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py
index 7c1019aa8..27b904fc8 100644
--- a/egs/alimeeting/ASR/local/text2segments.py
+++ b/egs/alimeeting/ASR/local/text2segments.py
@@ -30,8 +30,8 @@ with word segmenting:
import argparse
-import paddle
import jieba
+import paddle
from tqdm import tqdm
paddle.enable_static()
diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py
index 71be2a613..85047c367 100755
--- a/egs/alimeeting/ASR/local/text2token.py
+++ b/egs/alimeeting/ASR/local/text2token.py
@@ -56,9 +56,7 @@ def get_parser():
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
- parser.add_argument(
- "--space", default="", type=str, help="space symbol"
- )
+ parser.add_argument("--space", default="", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
@@ -66,9 +64,7 @@ def get_parser():
type=str,
help="list of non-linguistic symobles, e.g., etc.",
)
- parser.add_argument(
- "text", type=str, default=False, nargs="?", help="input text"
- )
+ parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
@@ -108,8 +104,7 @@ def token2id(
if token_type == "lazy_pinyin":
text = lazy_pinyin(chars_list)
sub_ids = [
- token_table[txt] if txt in token_table else oov_id
- for txt in text
+ token_table[txt] if txt in token_table else oov_id for txt in text
]
ids.append(sub_ids)
else: # token_type = "pinyin"
@@ -135,9 +130,7 @@ def main():
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
- f = codecs.getreader("utf-8")(
- sys.stdin if is_python2 else sys.stdin.buffer
- )
+ f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh
index 17224bb68..604cc92c6 100755
--- a/egs/alimeeting/ASR/prepare.sh
+++ b/egs/alimeeting/ASR/prepare.sh
@@ -1,5 +1,8 @@
#!/usr/bin/env bash
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
set -eou pipefail
stage=-1
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
index bf6faad7a..a9a4675a9 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -205,17 +205,13 @@ class AlimeetingAsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
- cuts_musan = load_manifest(
- self.args.manifest_dir / "musan_cuts.jsonl.gz"
- )
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
transforms.append(
- CutMix(
- cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
- )
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@@ -237,9 +233,7 @@ class AlimeetingAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
- logging.info(
- f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
- )
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@@ -282,9 +276,7 @@ class AlimeetingAsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@@ -341,9 +333,7 @@ class AlimeetingAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(
- Fbank(FbankConfig(num_mel_bins=80))
- ),
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
index 6358fe970..6c170c392 100755
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
@@ -70,11 +70,7 @@ from beam_search import (
from lhotse.cut import Cut
from train import get_params, get_transducer_model
-from icefall.checkpoint import (
- average_checkpoints,
- find_checkpoints,
- load_checkpoint,
-)
+from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@@ -193,8 +189,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@@ -249,9 +244,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
- encoder_out, encoder_out_lens = model.encoder(
- x=feature, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
@@ -266,10 +259,7 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -390,9 +380,7 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -403,18 +391,14 @@ def save_results(
):
test_set_wers = dict()
for key, results in results_dict.items():
- recog_path = (
- params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
- errs_filename = (
- params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
@@ -424,10 +408,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
- errs_info = (
- params.res_dir
- / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
- )
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
@@ -563,8 +544,7 @@ def main():
)
dev_shards = [
- str(path)
- for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+ str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
]
cuts_dev_webdataset = CutSet.from_webdataset(
dev_shards,
@@ -574,8 +554,7 @@ def main():
)
test_shards = [
- str(path)
- for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
+ str(path) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
]
cuts_test_webdataset = CutSet.from_webdataset(
test_shards,
@@ -588,9 +567,7 @@ def main():
return 1.0 <= c.duration
cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
- cuts_test_webdataset = cuts_test_webdataset.filter(
- remove_short_and_long_utt
- )
+ cuts_test_webdataset = cuts_test_webdataset.filter(remove_short_and_long_utt)
dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
index 8beec1b8a..8e5cc6075 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
@@ -103,8 +103,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
return parser
@@ -173,9 +172,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
index 93b1e1f57..f5a0dd8c8 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
@@ -162,8 +162,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -193,10 +192,9 @@ def read_sound_files(
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
- assert sample_rate == expected_sample_rate, (
- f"expected sample rate: {expected_sample_rate}. "
- f"Given: {sample_rate}"
- )
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@@ -257,9 +255,7 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
- features = pad_sequence(
- features, batch_first=True, padding_value=math.log(1e-10)
- )
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
@@ -284,10 +280,7 @@ def main():
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
- elif (
- params.decoding_method == "greedy_search"
- and params.max_sym_per_frame == 1
- ):
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@@ -339,9 +332,7 @@ def main():
if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
index 81a0ede7f..e57b5c859 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
@@ -81,9 +81,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
-LRSchedulerType = Union[
- torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
@@ -187,8 +185,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
- help="The context size in the decoder. 1 means bigram; "
- "2 means tri-gram",
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -211,8 +208,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
- help="The scale to smooth the loss with am (output of encoder network)"
- "part.",
+ help="The scale to smooth the loss with am (output of encoder network) part.",
)
parser.add_argument(
@@ -542,22 +538,15 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
- 0.0
- if warmup < 1.0
- else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
- )
- loss = (
- params.simple_loss_scale * simple_loss
- + pruned_loss_scale * pruned_loss
+ 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
+ loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- info["frames"] = (
- (feature_lens // params.subsampling_factor).sum().item()
- )
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -711,9 +700,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@@ -813,7 +800,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2 ** 22
+ 2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diff --git a/egs/alimeeting/ASR_v2/README.md b/egs/alimeeting/ASR_v2/README.md
new file mode 100644
index 000000000..f70327501
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/README.md
@@ -0,0 +1,38 @@
+
+# Introduction
+
+This recipe trains multi-domain ASR models for AliMeeting. By multi-domain, we mean that
+we train a single model on close-talk and far-field conditions. This recipe optionally
+uses [GSS]-based enhancement for far-field array microphone.
+We pool data in the following 4 ways and train a single model on the pooled data:
+
+(i) individual headset microphone (IHM)
+(ii) IHM with simulated reverb
+(iii) Single distant microphone (SDM)
+(iv) GSS-enhanced array microphones
+
+This is different from `alimeeting/ASR` since that recipe trains a model only on the
+far-field audio. Additionally, we use text normalization here similar to the original
+M2MeT challenge, so the results should be more comparable to those from Table 4 of
+the [paper](https://arxiv.org/abs/2110.07393).
+
+The following additional packages need to be installed to run this recipe:
+* `pip install jieba`
+* `pip install paddlepaddle`
+* `pip install git+https://github.com/desh2608/gss.git`
+
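+GSS-based enhancement requires a GPU and is optional. A minimal sketch of skipping it
+during data preparation (assuming the usual `parse_options.sh` handling, which maps
+`--use-gss` to the `use_gss` variable defined at the top of [prepare.sh](./prepare.sh)):
+
+```
+./prepare.sh --use-gss false
+```
+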
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+## Performance Record
+
+### pruned_transducer_stateless7
+
+The following are decoded using `modified_beam_search`:
+
+| Evaluation set | eval WER | test WER |
+|--------------------------|------------|---------|
+| IHM | 9.58 | 11.53 |
+| SDM | 23.37 | 25.85 |
+| MDM (GSS-enhanced) | 11.82 | 14.22 |
+
+See [RESULTS](/egs/alimeeting/ASR_v2/RESULTS.md) for details.
diff --git a/egs/alimeeting/ASR_v2/RESULTS.md b/egs/alimeeting/ASR_v2/RESULTS.md
new file mode 100644
index 000000000..15b24250d
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/RESULTS.md
@@ -0,0 +1,90 @@
+## Results (CER)
+
+#### 2022-12-09
+
+#### Zipformer (pruned_transducer_stateless7)
+
+Zipformer encoder + non-recurrent decoder. The decoder
+contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
+layer (to transform tensor dim).
+
+All the results below use a single model trained by combining the following
+data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise
+augmentation are applied on top of the pooled data.
+
+**WERs for IHM:**
+
+| | eval | test | comment |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search | 10.13 | 12.21 | --epoch 15 --avg 8 --max-duration 500 |
+| modified beam search | 9.58 | 11.53 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search | 9.92 | 12.07 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**WERs for SDM:**
+
+| | eval | test | comment |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search | 23.70 | 26.41 | --epoch 15 --avg 8 --max-duration 500 |
+| modified beam search | 23.37 | 25.85 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search | 23.60 | 26.38 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**WERs for GSS-enhanced MDM:**
+
+| | eval | test | comment |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search | 12.24 | 14.99 | --epoch 15 --avg 8 --max-duration 500 |
+| modified beam search | 11.82 | 14.22 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search | 12.30 | 14.98 | --epoch 15 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+ --world-size 4 \
+ --num-epochs 15 \
+ --exp-dir pruned_transducer_stateless7/exp \
+ --max-duration 300 \
+ --max-cuts 100 \
+ --prune-range 5 \
+ --lr-factor 5 \
+ --lm-scale 0.25 \
+ --use-fp16 True
+```
+
+The decoding command is:
+```
+# greedy search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 15 \
+ --avg 8 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method greedy_search
+
+# modified beam search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 15 \
+ --avg 8 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+# fast beam search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 15 \
+ --avg 8 \
+  --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+Pretrained model is available at
+
+The tensorboard training log can be found at
+
diff --git a/egs/alimeeting/ASR_v2/local/__init__.py b/egs/alimeeting/ASR_v2/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py
new file mode 100755
index 000000000..c6aa2ab36
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the AliMeeting dataset.
+For the training data, we prepare IHM, reverberated IHM, SDM, and GSS-enhanced
+audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced
+parts (which are the 3 evaluation settings).
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+import logging
+from pathlib import Path
+
+import torch
+import torch.multiprocessing
+from lhotse import CutSet, LilcomChunkyWriter
+from lhotse.features.kaldifeat import (
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+ KaldifeatFrameOptions,
+ KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def compute_fbank_alimeeting():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+
+ sampling_rate = 16000
+ num_mel_bins = 80
+
+ extractor = KaldifeatFbank(
+ KaldifeatFbankConfig(
+ frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+ mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+ device="cuda",
+ )
+ )
+
+ logging.info("Reading manifests")
+ manifests_ihm = read_manifests_if_cached(
+ dataset_parts=["train", "eval", "test"],
+ output_dir=src_dir,
+ prefix="alimeeting-ihm",
+ suffix="jsonl.gz",
+ )
+ manifests_sdm = read_manifests_if_cached(
+ dataset_parts=["train", "eval", "test"],
+ output_dir=src_dir,
+ prefix="alimeeting-sdm",
+ suffix="jsonl.gz",
+ )
+ # For GSS we already have cuts so we read them directly.
+ manifests_gss = read_manifests_if_cached(
+ dataset_parts=["train", "eval", "test"],
+ output_dir=src_dir,
+ prefix="alimeeting-gss",
+ suffix="jsonl.gz",
+ )
+
+ def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
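+        # Triple the training data by adding 0.9x and 1.1x speed-perturbed copies
+        # (the original cuts are kept) before extracting features.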
+ cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
+ _ = cuts.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=storage_path,
+ manifest_path=manifest_path,
+ batch_duration=5000,
+ num_workers=8,
+ storage_type=LilcomChunkyWriter,
+ )
+
+ logging.info(
+ "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)"
+ )
+
+ logging.info("Processing train split IHM")
+ cuts_ihm = (
+ CutSet.from_manifests(**manifests_ihm["train"])
+ .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+ .modify_ids(lambda x: x + "-ihm")
+ )
+ _extract_feats(
+ cuts_ihm,
+ output_dir / "feats_train_ihm",
+ src_dir / "cuts_train_ihm.jsonl.gz",
+ )
+
+ logging.info("Processing train split IHM + reverberated IHM")
+ cuts_ihm_rvb = cuts_ihm.reverb_rir()
+ _extract_feats(
+ cuts_ihm_rvb,
+ output_dir / "feats_train_ihm_rvb",
+ src_dir / "cuts_train_ihm_rvb.jsonl.gz",
+ )
+
+ logging.info("Processing train split SDM")
+ cuts_sdm = (
+ CutSet.from_manifests(**manifests_sdm["train"])
+ .trim_to_supervisions(keep_overlapping=False)
+ .modify_ids(lambda x: x + "-sdm")
+ )
+ _extract_feats(
+ cuts_sdm,
+ output_dir / "feats_train_sdm",
+ src_dir / "cuts_train_sdm.jsonl.gz",
+ )
+
+ logging.info("Processing train split GSS")
+ cuts_gss = (
+ CutSet.from_manifests(**manifests_gss["train"])
+ .trim_to_supervisions(keep_overlapping=False)
+ .modify_ids(lambda x: x + "-gss")
+ )
+ _extract_feats(
+ cuts_gss,
+ output_dir / "feats_train_gss",
+ src_dir / "cuts_train_gss.jsonl.gz",
+ )
+
+    logging.info("Preparing eval/test cuts: IHM, SDM, GSS (optional)")
+ for split in ["eval", "test"]:
+ logging.info(f"Processing {split} IHM")
+ cuts_ihm = (
+ CutSet.from_manifests(**manifests_ihm[split])
+ .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+ .compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir / f"feats_{split}_ihm",
+ manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz",
+ batch_duration=500,
+ num_workers=4,
+ storage_type=LilcomChunkyWriter,
+ )
+ )
+ logging.info(f"Processing {split} SDM")
+ cuts_sdm = (
+ CutSet.from_manifests(**manifests_sdm[split])
+ .trim_to_supervisions(keep_overlapping=False)
+ .compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir / f"feats_{split}_sdm",
+ manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz",
+ batch_duration=500,
+ num_workers=4,
+ storage_type=LilcomChunkyWriter,
+ )
+ )
+ logging.info(f"Processing {split} GSS")
+ cuts_gss = (
+ CutSet.from_manifests(**manifests_gss[split])
+ .trim_to_supervisions(keep_overlapping=False)
+ .compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir / f"feats_{split}_gss",
+ manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
+ batch_duration=500,
+ num_workers=4,
+ storage_type=LilcomChunkyWriter,
+ )
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+    compute_fbank_alimeeting()
diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py b/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py
new file mode 100644
index 000000000..f1512efa5
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_enhanced.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Data preparation for AliMeeting GSS-enhanced dataset.
+
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+from lhotse import Recording, RecordingSet, SupervisionSet
+from lhotse.qa import fix_manifests
+from lhotse.recipes.utils import read_manifests_if_cached
+from lhotse.utils import fastcopy
+from tqdm import tqdm
+
+logging.basicConfig(
+ format="%(asctime)s %(levelname)-8s %(message)s",
+ level=logging.INFO,
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def get_args():
+ import argparse
+
+    parser = argparse.ArgumentParser(
+        description="AliMeeting enhanced dataset preparation."
+    )
+ parser.add_argument(
+ "manifests_dir",
+ type=Path,
+ help="Path to directory containing AliMeeting manifests.",
+ )
+ parser.add_argument(
+ "enhanced_dir",
+ type=Path,
+ help="Path to enhanced data directory.",
+ )
+ parser.add_argument(
+ "--num-jobs",
+ "-j",
+ type=int,
+ default=1,
+ help="Number of parallel jobs to run.",
+ )
+ parser.add_argument(
+ "--min-segment-duration",
+ "-d",
+ type=float,
+ default=0.0,
+ help="Minimum duration of a segment in seconds.",
+ )
+ return parser.parse_args()
+
+
+def find_recording_and_create_new_supervision(enhanced_dir, supervision):
+ """
+    Given a supervision (corresponding to an original AliMeeting recording), this
+    function finds the enhanced recording corresponding to the supervision, and
+    returns this recording and a new supervision whose start and end times are
+    adjusted to match the enhanced recording.
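+
+    Illustrative example (recording and speaker IDs are hypothetical): a supervision
+    on recording R0003_M0046 by speaker SPK001 spanning 12.0-15.5 s is expected at
+    <enhanced_dir>/R0003_M0046/R0003_M0046-SPK001-001200_001550.flac.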
+ """
+ file_name = Path(
+ f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac"
+ )
+ save_path = enhanced_dir / f"{supervision.recording_id}" / file_name
+ if save_path.exists():
+ recording = Recording.from_file(save_path)
+ if recording.duration == 0:
+ logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
+ return None
+
+        # The old supervision is w.r.t. the original recording; we create a new
+        # supervision w.r.t. the enhanced segment.
+ new_supervision = fastcopy(
+ supervision,
+ recording_id=recording.id,
+ start=0,
+ duration=recording.duration,
+ )
+ return recording, new_supervision
+ else:
+ logging.warning(f"{save_path} does not exist.")
+ return None
+
+
+def main(args):
+ # Get arguments
+ manifests_dir = args.manifests_dir
+ enhanced_dir = args.enhanced_dir
+
+ # Load manifests from cache if they exist (saves time)
+ manifests = read_manifests_if_cached(
+ dataset_parts=["train", "eval", "test"],
+ output_dir=manifests_dir,
+ prefix="alimeeting-sdm",
+ suffix="jsonl.gz",
+ )
+ if not manifests:
+ raise ValueError(
+ "AliMeeting SDM manifests not found in {}".format(manifests_dir)
+ )
+
+ with ThreadPoolExecutor(args.num_jobs) as ex:
+ for part in ["train", "eval", "test"]:
+ logging.info(f"Processing {part}...")
+ supervisions_orig = manifests[part]["supervisions"].filter(
+ lambda s: s.duration >= args.min_segment_duration
+ )
+ futures = []
+
+ for supervision in tqdm(
+ supervisions_orig,
+ desc="Distributing tasks",
+ ):
+ futures.append(
+ ex.submit(
+ find_recording_and_create_new_supervision,
+ enhanced_dir,
+ supervision,
+ )
+ )
+
+ recordings = []
+ supervisions = []
+ for future in tqdm(
+ futures,
+ total=len(futures),
+ desc="Processing tasks",
+ ):
+ result = future.result()
+ if result is not None:
+ recording, new_supervision = result
+ recordings.append(recording)
+ supervisions.append(new_supervision)
+
+ # Remove duplicates from the recordings
+ recordings_nodup = {}
+ for recording in recordings:
+ if recording.id not in recordings_nodup:
+ recordings_nodup[recording.id] = recording
+ else:
+ logging.warning("Recording {} is duplicated.".format(recording.id))
+ recordings = RecordingSet.from_recordings(recordings_nodup.values())
+ supervisions = SupervisionSet.from_segments(supervisions)
+
+ recordings, supervisions = fix_manifests(
+ recordings=recordings, supervisions=supervisions
+ )
+
+ logging.info(f"Writing {part} enhanced manifests")
+ recordings.to_file(
+ manifests_dir / f"alimeeting-gss_recordings_{part}.jsonl.gz"
+ )
+ supervisions.to_file(
+ manifests_dir / f"alimeeting-gss_supervisions_{part}.jsonl.gz"
+ )
+
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh
new file mode 100755
index 000000000..76db19832
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+# This script is used to run GSS-based enhancement on AliMeeting data.
+set -euo pipefail
+nj=4
+stage=0
+
+. shared/parse_options.sh || exit 1
+
+if [ $# != 2 ]; then
+ echo "Wrong #arguments ($#, expected 2)"
+  echo "Usage: local/prepare_alimeeting_gss.sh [options] <data-dir> <exp-dir>"
+  echo "e.g. local/prepare_alimeeting_gss.sh data/manifests exp/alimeeting_gss"
+ echo "main options (for others, see top of script file)"
+ echo " --nj # number of parallel jobs"
+ echo " --stage # stage to start running from"
+ exit 1;
+fi
+
+DATA_DIR=$1
+EXP_DIR=$2
+
+mkdir -p $EXP_DIR
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 1 ]; then
+ log "Stage 1: Prepare cut sets"
+ for part in train eval test; do
+ lhotse cut simple \
+ -r $DATA_DIR/alimeeting-mdm_recordings_${part}.jsonl.gz \
+ -s $DATA_DIR/alimeeting-mdm_supervisions_${part}.jsonl.gz \
+ $EXP_DIR/cuts_${part}.jsonl.gz
+ done
+fi
+
+if [ $stage -le 2 ]; then
+ log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)"
+ for part in train eval test; do
+ lhotse cut trim-to-supervisions --discard-overlapping \
+ $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz
+ done
+fi
+
+if [ $stage -le 3 ]; then
+ log "Stage 3: Split manifests for multi-GPU processing (optional)"
+ for part in train eval test; do
+ gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \
+ $EXP_DIR/cuts_per_segment_${part}_split$nj
+ done
+fi
+
+if [ $stage -le 4 ]; then
+ log "Stage 4: Enhance train segments using GSS (requires GPU)"
+ # for train, we use smaller context and larger batches to speed-up processing
+ for JOB in $(seq $nj); do
+ gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \
+ --bss-iterations 10 \
+ --context-duration 5.0 \
+ --use-garbage-class \
+ --channels 0,1,2,3,4,5,6,7 \
+ --min-segment-length 0.05 \
+ --max-segment-length 25.0 \
+ --max-batch-duration 60.0 \
+ --num-buckets 4 \
+ --num-workers 4
+ done
+fi
+
+if [ $stage -le 5 ]; then
+ log "Stage 5: Enhance eval/test segments using GSS (using GPU)"
+ # for eval/test, we use larger context and smaller batches to get better quality
+ for part in eval test; do
+ for JOB in $(seq $nj); do
+ gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \
+ $EXP_DIR/enhanced \
+ --bss-iterations 10 \
+ --context-duration 15.0 \
+ --use-garbage-class \
+ --channels 0,1,2,3,4,5,6,7 \
+ --min-segment-length 0.05 \
+ --max-segment-length 16.0 \
+ --max-batch-duration 45.0 \
+ --num-buckets 4 \
+ --num-workers 4
+ done
+ done
+fi
+
+if [ $stage -le 6 ]; then
+ log "Stage 6: Prepare manifests for GSS-enhanced data"
+ python local/prepare_alimeeting_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05
+fi
diff --git a/egs/alimeeting/ASR_v2/local/prepare_char.py b/egs/alimeeting/ASR_v2/local/prepare_char.py
new file mode 120000
index 000000000..ee5dd34f1
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_char.py
@@ -0,0 +1 @@
+../../ASR/local/prepare_char.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/prepare_words.py b/egs/alimeeting/ASR_v2/local/prepare_words.py
new file mode 120000
index 000000000..970bfd60c
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/prepare_words.py
@@ -0,0 +1 @@
+../../ASR/local/prepare_words.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/text2segments.py b/egs/alimeeting/ASR_v2/local/text2segments.py
new file mode 120000
index 000000000..bf4547794
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/text2segments.py
@@ -0,0 +1 @@
+../../ASR/local/text2segments.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/local/text2token.py b/egs/alimeeting/ASR_v2/local/text2token.py
new file mode 120000
index 000000000..f6b8531b6
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/local/text2token.py
@@ -0,0 +1 @@
+../../ASR/local/text2token.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh
new file mode 100755
index 000000000..76a108771
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/prepare.sh
@@ -0,0 +1,125 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+use_gss=true # Use GSS-based enhancement with MDM setting
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/alimeeting
+# This directory contains the following files downloaded from
+#   the AliMeeting (M2MeT) corpus release (stage 0 below can fetch them via `lhotse download ali-meeting`)
+#
+# - Train_Ali_far.tar.gz
+# - Train_Ali_near.tar.gz
+# - Test_Ali.tar.gz
+# - Eval_Ali.tar.gz
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ if [ ! -f $dl_dir/alimeeting/Train_Ali_far.tar.gz ]; then
+ lhotse download ali-meeting $dl_dir/alimeeting
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare alimeeting manifest"
+ # We assume that you have downloaded the alimeeting corpus
+ # to $dl_dir/alimeeting
+ for part in ihm sdm mdm; do
+ mkdir -p data/manifests/alimeeting
+ lhotse prepare ali-meeting --mic $part --save-mono --normalize-text m2met \
+ $dl_dir/alimeeting data/manifests
+ done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+  # to $dl_dir/musan
+ mkdir -p data/manifests
+ lhotse prepare musan $dl_dir/musan data/manifests
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then
+ log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
+ # We assume that you have installed the GSS package: https://github.com/desh2608/gss
+ local/prepare_alimeeting_gss.sh data/manifests exp/alimeeting_gss
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank for musan"
+ mkdir -p data/fbank
+ python local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Compute fbank for alimeeting"
+ mkdir -p data/fbank
+ python local/compute_fbank_alimeeting.py
+ log "Combine features from train splits"
+ lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
+ gzip -c > data/manifests/cuts_train_all.jsonl.gz
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Prepare char based lang"
+ lang_char_dir=data/lang_char
+ mkdir -p $lang_char_dir
+
+ # Prepare text.
+ # Note: in Linux, you can install jq with the following command:
+ # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
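+  #   followed by: chmod +x ./jq && mv ./jq ~/.local/bin/ (or any directory on your PATH)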
+ gunzip -c data/manifests/alimeeting-sdm_supervisions_train.jsonl.gz \
+ | jq ".text" | sed 's/"//g' \
+ | ./local/text2token.py -t "char" > $lang_char_dir/text
+
+ # Prepare words segments
+ python ./local/text2segments.py \
+ --input $lang_char_dir/text \
+ --output $lang_char_dir/text_words_segmentation
+
+ cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \
+ | sort -u | sed "/^$/d" \
+ | uniq > $lang_char_dir/words_no_ids.txt
+
+ # Prepare words.txt
+ if [ ! -f $lang_char_dir/words.txt ]; then
+ ./local/prepare_words.py \
+ --input-file $lang_char_dir/words_no_ids.txt \
+ --output-file $lang_char_dir/words.txt
+ fi
+
+ if [ ! -f $lang_char_dir/L_disambig.pt ]; then
+ ./local/prepare_char.py
+ fi
+fi
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
new file mode 100644
index 000000000..1cfd053c7
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
@@ -0,0 +1,419 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import re
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.cut import Cut
+from lhotse.dataset import (
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class AlimeetingAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+ This class should be derived for specific corpora used in ASR tasks.
+ """
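+
+    # Typical usage (a sketch; it mirrors how train.py/decode.py in this recipe
+    # use the data module):
+    #   parser = argparse.ArgumentParser()
+    #   AlimeetingAsrDataModule.add_arguments(parser)
+    #   args = parser.parse_args()
+    #   alimeeting = AlimeetingAsrDataModule(args)
+    #   train_dl = alimeeting.train_dataloaders(alimeeting.train_cuts())
+    #   valid_dl = alimeeting.valid_dataloaders(alimeeting.eval_ihm_cuts())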
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description=(
+ "These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc."
+ ),
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/manifests"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help=(
+ "When enabled, select noise from MUSAN and mix it "
+ "with training dataset. "
+ ),
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help=(
+ "When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding."
+ ),
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help=(
+ "Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch."
+ ),
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help=(
+ "The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used."
+ ),
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=100.0,
+ help=(
+ "Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM."
+ ),
+ )
+ group.add_argument(
+ "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch."
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=50,
+        help=(
+            "The number of buckets for the BucketingSampler "
+            "(you might want to increase it for larger datasets)."
+ ),
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help=(
+ "When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available."
+ ),
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help=(
+ "When enabled (=default), the examples will be "
+ "shuffled for each epoch."
+ ),
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=8,
+        help=(
+            "The number of training dataloader workers that collect the batches."
+ ),
+ )
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help=(
+ "Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp."
+ ),
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to get Musan cuts")
+
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ "Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=2,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ if self.args.on_the_fly_feats:
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ )
+ else:
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ )
+
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ max_cuts=self.args.max_cuts,
+            shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=True,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=True,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts, max_duration=self.args.max_duration, shuffle=False
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ def remove_short_cuts(self, cut: Cut) -> bool:
+ """
+ See: https://github.com/k2-fsa/icefall/issues/500
+ Basically, the zipformer model subsamples the input using the following formula:
+ num_out_frames = ((num_in_frames - 7)//2 + 1)//2
+ For num_out_frames to be at least 1, num_in_frames must be at least 9.
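+        A worked example: with 9 input frames, ((9 - 7)//2 + 1)//2 = 1 output frame;
+        at the default 10 ms frame shift, 9 frames correspond to the 0.09 s threshold
+        used below.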
+ """
+ return cut.duration >= 0.09
+
+ @lru_cache()
+ def train_cuts(self, sp: Optional[Any] = None) -> CutSet:
+        logging.info("About to get AliMeeting train cuts")
+
+ def _remove_short_and_long_utt(c: Cut):
+ if c.duration < 0.1 or c.duration > 25.0:
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = c.supervisions[0].text
+ return T >= len(tokens)
+
+ cuts_train = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_train_all.jsonl.gz"
+ )
+
+ return cuts_train.filter(_remove_short_and_long_utt)
+
+ @lru_cache()
+ def eval_ihm_cuts(self) -> CutSet:
+ logging.info("About to get AliMeeting IHM eval cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_ihm.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
+
+ @lru_cache()
+ def eval_sdm_cuts(self) -> CutSet:
+ logging.info("About to get AliMeeting SDM eval cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_sdm.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
+
+ @lru_cache()
+ def eval_gss_cuts(self) -> CutSet:
+ if not (self.args.manifest_dir / "cuts_eval_gss.jsonl.gz").exists():
+ logging.info("No GSS dev cuts found")
+ return None
+ logging.info("About to get AliMeeting GSS-enhanced eval cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_eval_gss.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
+
+ @lru_cache()
+ def test_ihm_cuts(self) -> CutSet:
+ logging.info("About to get AliMeeting IHM test cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
+
+ @lru_cache()
+ def test_sdm_cuts(self) -> CutSet:
+ logging.info("About to get AliMeeting SDM test cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
+
+ @lru_cache()
+ def test_gss_cuts(self) -> CutSet:
+ if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
+ logging.info("No GSS test cuts found")
+ return None
+ logging.info("About to get AliMeeting GSS-enhanced test cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py
new file mode 120000
index 000000000..37516affc
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py
new file mode 100755
index 000000000..2741e0eeb
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py
@@ -0,0 +1,692 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 15 \
+ --avg 8 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method greedy_search
+
+(2) modified beam search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 15 \
+ --avg 8 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(3) fast beam search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 15 \
+ --avg 8 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AlimeetingAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall import NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 0.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=10,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+        default="pruned_transducer_stateless7/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+        help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ lexicon:
+ The lexicon; its token table is used to map token IDs to token strings.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+ only when --decoding-method is fast_beam_search or
+ fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_size_7" if a beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 100
+ else:
+ log_interval = 2
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
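+ # Scoring for this Mandarin recipe is character-based, so strip spaces
+ # and split each reference transcript into a list of characters.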
+ texts = [list(str(text).replace(" ", "")) for text in texts]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ this_batch.append((cut_id, ref_text, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AlimeetingAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest_LG",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ if "fast_beam_search" in params.decoding_method:
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ alimeeting = AlimeetingAsrDataModule(args)
+
+ eval_ihm_cuts = alimeeting.eval_ihm_cuts()
+ test_ihm_cuts = alimeeting.test_ihm_cuts()
+ eval_sdm_cuts = alimeeting.eval_sdm_cuts()
+ test_sdm_cuts = alimeeting.test_sdm_cuts()
+ eval_gss_cuts = alimeeting.eval_gss_cuts()
+ test_gss_cuts = alimeeting.test_gss_cuts()
+
+ eval_ihm_dl = alimeeting.test_dataloaders(eval_ihm_cuts)
+ test_ihm_dl = alimeeting.test_dataloaders(test_ihm_cuts)
+ eval_sdm_dl = alimeeting.test_dataloaders(eval_sdm_cuts)
+ test_sdm_dl = alimeeting.test_dataloaders(test_sdm_cuts)
+ if eval_gss_cuts is not None:
+ eval_gss_dl = alimeeting.test_dataloaders(eval_gss_cuts)
+ if test_gss_cuts is not None:
+ test_gss_dl = alimeeting.test_dataloaders(test_gss_cuts)
+
+ test_sets = {
+ "eval_ihm": (eval_ihm_dl, eval_ihm_cuts),
+ "test_ihm": (test_ihm_dl, test_ihm_cuts),
+ "eval_sdm": (eval_sdm_dl, eval_sdm_cuts),
+ "test_sdm": (test_sdm_dl, test_sdm_cuts),
+ }
+ if eval_gss_cuts is not None:
+ test_sets["eval_gss"] = (eval_gss_dl, eval_gss_cuts)
+ if test_gss_cuts is not None:
+ test_sets["test_gss"] = (test_gss_dl, test_gss_cuts)
+
+ for test_set in test_sets:
+ logging.info(f"Decoding {test_set}")
+ dl, cuts = test_sets[test_set]
+ results_dict = decode_dataset(
+ dl=dl,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py
new file mode 120000
index 000000000..0c2673d46
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
new file mode 100755
index 000000000..23a88dd29
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./pruned_transducer_stateless7/export.py \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --lang-dir data/lang_char \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
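+
+A minimal sketch (illustrative only, not part of this recipe's scripts) of
+loading the exported model from Python:
+
+ import torch
+ model = torch.jit.load("cpu_jit.pt")
+ model.eval()
+ # Move to GPU if needed, e.g. model.to("cuda")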
+
+(2) Export `model.state_dict()`
+
+./pruned_transducer_stateless7/export.py \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --lang-dir data/lang_char \
+ --epoch 20 \
+ --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `pruned_transducer_stateless7/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/alimeeting/ASR_v2
+ ./pruned_transducer_stateless7/decode.py \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --decoding-method greedy_search \
+ --lang-dir data/lang_char
+
+Check ./pretrained.py for its usage.
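+
+To load `pretrained.pt` directly in Python (a sketch; `model` is assumed to be
+created by get_transducer_model() from ./train.py):
+
+ from icefall.checkpoint import load_checkpoint
+ load_checkpoint("pretrained.pt", model)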
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+
+with the following commands:
+
+ sudo apt-get install git-lfs
+ git lfs install
+ git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+ # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=15,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=8,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
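+ # Illustrative example (derived from the help above and the logic in main()):
+ # --epoch 15 --avg 8 --use-averaged-model False
+ # averages the checkpoints epoch-8.pt ... epoch-15.pt;
+ # --epoch 15 --avg 8 --use-averaged-model True
+ # combines the running averages stored in epoch-7.pt (excluded endpoint)
+ # and epoch-15.pt.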
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ It will generate a file named cpu_jit.pt
+
+ Check ./jit_pretrained.py for how to use it.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ model.to(device)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit is True:
+ convert_scaled_to_non_scaled(model, inplace=True)
+ logging.info("Using torch.jit.script()")
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+ # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torchscript. Export model.state_dict()")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py
new file mode 120000
index 000000000..a44034e34
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py
new file mode 120000
index 000000000..068f0f57f
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py
new file mode 120000
index 000000000..7ceac5d10
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..757d6535e
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1186 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+ --world-size 4 \
+ --num-epochs 15 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless7/exp \
+ --max-duration 150 \
+ --use-fp16 True
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AlimeetingAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,4,3,2,4",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="1024,1024,2048,2048,1024",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="384,384,384,384,384",
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+ not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+ " worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=15,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.05, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=5000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=10,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
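+ # For example (just the arithmetic implied by the formula above): with
+ # --average-period 200, at batch_idx_train=1000 the update is
+ # model_avg = model * (200 / 1000) + model_avg * (800 / 1000),
+ # i.e. the current model contributes 20% and the running average 80%.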
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if `batch_idx % log_interval` is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 100,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.exp_dir/epoch-{params.start_epoch - 1}.pt`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+
+ y = graph_compiler.texts_to_ids(texts)
+ if isinstance(y, list):
+ y = k2.RaggedTensor(y).to(device)
+ else:
+ y = y.to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = ((feature_lens - 7) // 2).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ alimeeting = AlimeetingAsrDataModule(args)
+
+ train_cuts = alimeeting.train_cuts()
+ train_dl = alimeeting.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = alimeeting.eval_ihm_cuts()
+ valid_dl = alimeeting.valid_dataloaders(valid_cuts)
+
+ # if not params.print_diagnostics:
+ # scan_pessimistic_batches_for_oom(
+ # model=model,
+ # train_dl=train_dl,
+ # optimizer=optimizer,
+ # graph_compiler=graph_compiler,
+ # params=params,
+ # )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ graph_compiler:
+ The compiler that converts transcripts to token IDs.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ AlimeetingAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR_v2/shared b/egs/alimeeting/ASR_v2/shared
new file mode 120000
index 000000000..3a3b28f96
--- /dev/null
+++ b/egs/alimeeting/ASR_v2/shared
@@ -0,0 +1 @@
+../../../egs/aishell/ASR/shared
\ No newline at end of file
diff --git a/egs/ami/ASR/README.md b/egs/ami/ASR/README.md
new file mode 100644
index 000000000..1c9714bd4
--- /dev/null
+++ b/egs/ami/ASR/README.md
@@ -0,0 +1,48 @@
+# AMI
+
+This is an ASR recipe for the AMI corpus. AMI provides recordings from the speaker's
+headset and lapel microphones, and also 2 array microphones containing 8 channels each.
+We pool data in the following 4 ways and train a single model on the pooled data:
+
+(i) individual headset microphone (IHM)
+(ii) IHM with simulated reverb
+(iii) Single distant microphone (SDM)
+(iv) GSS-enhanced array microphones
+
+Speed perturbation and MUSAN noise augmentation are additionally performed on the pooled
+data. Here are the statistics of the combined training data:
+
+```python
+>>> cuts_train.describe()
+Cuts count: 1222053
+Total duration (hh:mm:ss): 905:00:28
+Speech duration (hh:mm:ss): 905:00:28 (99.9%)
+Duration statistics (seconds):
+mean 2.7
+std 2.8
+min 0.0
+25% 0.6
+50% 1.6
+75% 3.8
+99% 12.3
+99.5% 13.9
+99.9% 18.4
+max 36.8
+```
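+
+A minimal sketch of this pooling with lhotse (the file names below are
+placeholders; see `prepare.sh` and the scripts in `local/` for the actual
+pipeline):
+
+```python
+from lhotse import CutSet
+
+# One CutSet per condition, assumed to have been prepared already.
+cuts_ihm = CutSet.from_file("cuts_ihm.jsonl.gz")
+cuts_ihm_rvb = CutSet.from_file("cuts_ihm_rvb.jsonl.gz")
+cuts_sdm = CutSet.from_file("cuts_sdm.jsonl.gz")
+cuts_gss = CutSet.from_file("cuts_gss.jsonl.gz")
+
+# Pool the 4 conditions into a single training set.
+cuts_train = cuts_ihm + cuts_ihm_rvb + cuts_sdm + cuts_gss
+
+# 3-way speed perturbation on the pooled data.
+cuts_train = (
+    cuts_train + cuts_train.perturb_speed(0.9) + cuts_train.perturb_speed(1.1)
+)
+cuts_train.describe()
+```
+
+MUSAN noise augmentation is typically applied on the fly in the training
+dataloader rather than baked into the manifests.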
+
+**Note:** This recipe additionally uses [GSS](https://github.com/desh2608/gss) for enhancement
+of far-field array microphones, but this is optional (see `prepare.sh` for details).
+
+## Performance Record
+
+### pruned_transducer_stateless7
+
+The following are decoded using `modified_beam_search`:
+
+| Evaluation set | dev WER | test WER |
+|--------------------------|------------|---------|
+| IHM | 18.92 | 17.40 |
+| SDM | 31.25 | 32.21 |
+| MDM (GSS-enhanced) | 21.67 | 22.43 |
+
+See [RESULTS](/egs/ami/ASR/RESULTS.md) for details.
diff --git a/egs/ami/ASR/RESULTS.md b/egs/ami/ASR/RESULTS.md
new file mode 100644
index 000000000..163986021
--- /dev/null
+++ b/egs/ami/ASR/RESULTS.md
@@ -0,0 +1,92 @@
+## Results
+
+### AMI training results (Pruned Transducer)
+
+#### 2022-11-20
+
+#### Zipformer (pruned_transducer_stateless7)
+
+Zipformer encoder + non-recurrent decoder. The decoder
+contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
+layer (to transform the tensor dimension).
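+
+A minimal sketch of such a stateless decoder is given below for orientation only; the
+actual module lives in `pruned_transducer_stateless7/decoder.py` and differs in details
+(e.g. scaling and how the context is padded):
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class StatelessDecoder(nn.Module):
+    """Embedding -> causal Conv1d (kernel size = context size) -> Linear."""
+
+    def __init__(self, vocab_size: int, decoder_dim: int, context_size: int = 2):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, decoder_dim)
+        self.conv = nn.Conv1d(decoder_dim, decoder_dim, kernel_size=context_size)
+        self.out = nn.Linear(decoder_dim, decoder_dim)
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        # y: (N, U) token ids -> (N, U, decoder_dim)
+        emb = self.embedding(y).permute(0, 2, 1)  # (N, C, U)
+        emb = F.pad(emb, (self.conv.kernel_size[0] - 1, 0))  # causal left padding
+        return self.out(F.relu(self.conv(emb)).permute(0, 2, 1))
+```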
+
+All the results below are using a single model that is trained by combining the following
+data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise
+augmentation are applied on top of the pooled data.
+
+**WERs for IHM:**
+
+| | dev | test | comment |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search | 19.25 | 17.83 | --epoch 14 --avg 8 --max-duration 500 |
+| modified beam search | 18.92 | 17.40 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search | 19.44 | 18.04 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**WERs for SDM:**
+
+| | dev | test | comment |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search | 31.32 | 32.38 | --epoch 14 --avg 8 --max-duration 500 |
+| modified beam search | 31.25 | 32.21 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search | 31.11 | 32.10 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+**WERs for GSS-enhanced MDM:**
+
+| | dev | test | comment |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search | 22.05 | 22.93 | --epoch 14 --avg 8 --max-duration 500 |
+| modified beam search | 21.67 | 22.43 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 |
+| fast beam search | 22.21 | 22.83 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+ --world-size 4 \
+ --num-epochs 15 \
+ --exp-dir pruned_transducer_stateless7/exp \
+ --max-duration 150 \
+ --max-cuts 150 \
+ --prune-range 5 \
+ --lr-factor 5 \
+ --lm-scale 0.25 \
+ --use-fp16 True
+```
+
+The decoding command is:
+```
+# greedy search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 14 \
+ --avg 8 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method greedy_search
+
+# modified beam search
+./pruned_transducer_stateless7/decode.py \
+  --epoch 14 \
+  --avg 8 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+# fast beam search
+./pruned_transducer_stateless7/decode.py \
+  --epoch 14 \
+  --avg 8 \
+  --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+Pretrained model is available at
+
+The tensorboard training log can be found at
+
diff --git a/egs/ami/ASR/local/__init__.py b/egs/ami/ASR/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/ami/ASR/local/compute_fbank_ami.py b/egs/ami/ASR/local/compute_fbank_ami.py
new file mode 100755
index 000000000..4892b40e3
--- /dev/null
+++ b/egs/ami/ASR/local/compute_fbank_ami.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the AMI dataset.
+For the training data, we pool together IHM, reverberated IHM, and GSS-enhanced
+audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced
+parts (which are the 3 evaluation settings).
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+import logging
+import math
+from pathlib import Path
+
+import torch
+import torch.multiprocessing
+from lhotse import CutSet, LilcomChunkyWriter
+from lhotse.features.kaldifeat import (
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+ KaldifeatFrameOptions,
+ KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled, or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking main() (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def compute_fbank_ami():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+
+ sampling_rate = 16000
+ num_mel_bins = 80
+
+ extractor = KaldifeatFbank(
+ KaldifeatFbankConfig(
+ frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+ mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+ device="cuda",
+ )
+ )
+
+ logging.info("Reading manifests")
+ manifests_ihm = read_manifests_if_cached(
+ dataset_parts=["train", "dev", "test"],
+ output_dir=src_dir,
+ prefix="ami-ihm",
+ suffix="jsonl.gz",
+ )
+ manifests_sdm = read_manifests_if_cached(
+ dataset_parts=["train", "dev", "test"],
+ output_dir=src_dir,
+ prefix="ami-sdm",
+ suffix="jsonl.gz",
+ )
+ # For GSS we already have cuts so we read them directly.
+ manifests_gss = read_manifests_if_cached(
+ dataset_parts=["train", "dev", "test"],
+ output_dir=src_dir,
+ prefix="ami-gss",
+ suffix="jsonl.gz",
+ )
+
+ def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
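+        # Triple the data: the original cuts plus 0.9x and 1.1x speed-perturbed copies.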
+ cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
+ _ = cuts.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=storage_path,
+ manifest_path=manifest_path,
+ batch_duration=5000,
+ num_workers=8,
+ storage_type=LilcomChunkyWriter,
+ )
+
+ logging.info(
+ "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)"
+ )
+
+ logging.info("Processing train split IHM")
+ cuts_ihm = (
+ CutSet.from_manifests(**manifests_ihm["train"])
+ .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+ .modify_ids(lambda x: x + "-ihm")
+ )
+ _extract_feats(
+ cuts_ihm,
+ output_dir / "feats_train_ihm",
+ src_dir / "cuts_train_ihm.jsonl.gz",
+ )
+
+ logging.info("Processing train split IHM + reverberated IHM")
+ cuts_ihm_rvb = cuts_ihm.reverb_rir()
+ _extract_feats(
+ cuts_ihm_rvb,
+ output_dir / "feats_train_ihm_rvb",
+ src_dir / "cuts_train_ihm_rvb.jsonl.gz",
+ )
+
+ logging.info("Processing train split SDM")
+ cuts_sdm = (
+ CutSet.from_manifests(**manifests_sdm["train"])
+ .trim_to_supervisions(keep_overlapping=False)
+ .modify_ids(lambda x: x + "-sdm")
+ )
+ _extract_feats(
+ cuts_sdm,
+ output_dir / "feats_train_sdm",
+ src_dir / "cuts_train_sdm.jsonl.gz",
+ )
+
+ logging.info("Processing train split GSS")
+ cuts_gss = (
+ CutSet.from_manifests(**manifests_gss["train"])
+ .trim_to_supervisions(keep_overlapping=False)
+ .modify_ids(lambda x: x + "-gss")
+ )
+ _extract_feats(
+ cuts_gss,
+ output_dir / "feats_train_gss",
+ src_dir / "cuts_train_gss.jsonl.gz",
+ )
+
+ logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
+ for split in ["dev", "test"]:
+ logging.info(f"Processing {split} IHM")
+ cuts_ihm = (
+ CutSet.from_manifests(**manifests_ihm[split])
+ .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False)
+ .compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir / f"feats_{split}_ihm",
+ manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz",
+ batch_duration=5000,
+ num_workers=8,
+ storage_type=LilcomChunkyWriter,
+ )
+ )
+ logging.info(f"Processing {split} SDM")
+ cuts_sdm = (
+ CutSet.from_manifests(**manifests_sdm[split])
+ .trim_to_supervisions(keep_overlapping=False)
+ .compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir / f"feats_{split}_sdm",
+ manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz",
+ batch_duration=500,
+ num_workers=4,
+ storage_type=LilcomChunkyWriter,
+ )
+ )
+ logging.info(f"Processing {split} GSS")
+ cuts_gss = (
+ CutSet.from_manifests(**manifests_gss[split])
+ .trim_to_supervisions(keep_overlapping=False)
+ .compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir / f"feats_{split}_gss",
+ manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz",
+ batch_duration=500,
+ num_workers=4,
+ storage_type=LilcomChunkyWriter,
+ )
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ compute_fbank_ami()
diff --git a/egs/ami/ASR/local/compute_fbank_musan.py b/egs/ami/ASR/local/compute_fbank_musan.py
new file mode 100755
index 000000000..1fcf951f9
--- /dev/null
+++ b/egs/ami/ASR/local/compute_fbank_musan.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the musan dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, LilcomChunkyWriter, combine
+from lhotse.features.kaldifeat import (
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+ KaldifeatFrameOptions,
+ KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled, or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking main() (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_musan():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+
+ sampling_rate = 16000
+ num_mel_bins = 80
+
+ dataset_parts = (
+ "music",
+ "speech",
+ "noise",
+ )
+ prefix = "musan"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ musan_cuts_path = src_dir / "musan_cuts.jsonl.gz"
+
+ if musan_cuts_path.is_file():
+ logging.info(f"{musan_cuts_path} already exists - skipping")
+ return
+
+ logging.info("Extracting features for Musan")
+
+ extractor = KaldifeatFbank(
+ KaldifeatFbankConfig(
+ frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+ mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+ device="cuda",
+ )
+ )
+
+ # create chunks of Musan with duration 5 - 10 seconds
+ _ = (
+ CutSet.from_manifests(
+ recordings=combine(part["recordings"] for part in manifests.values())
+ )
+ .cut_into_windows(10.0)
+ .filter(lambda c: c.duration > 5)
+ .compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir / "musan_feats",
+ manifest_path=musan_cuts_path,
+ batch_duration=500,
+ num_workers=4,
+ storage_type=LilcomChunkyWriter,
+ )
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ compute_fbank_musan()
diff --git a/egs/ami/ASR/local/prepare_ami_enhanced.py b/egs/ami/ASR/local/prepare_ami_enhanced.py
new file mode 100644
index 000000000..bed220eb3
--- /dev/null
+++ b/egs/ami/ASR/local/prepare_ami_enhanced.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Data preparation for AMI GSS-enhanced dataset.
+
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+from lhotse import Recording, RecordingSet, SupervisionSet
+from lhotse.qa import fix_manifests
+from lhotse.recipes.utils import read_manifests_if_cached
+from lhotse.utils import fastcopy
+from tqdm import tqdm
+
+logging.basicConfig(
+ format="%(asctime)s %(levelname)-8s %(message)s",
+ level=logging.INFO,
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+
+def get_args():
+ import argparse
+
+ parser = argparse.ArgumentParser(description="AMI enhanced dataset preparation.")
+ parser.add_argument(
+ "manifests_dir",
+ type=Path,
+ help="Path to directory containing AMI manifests.",
+ )
+ parser.add_argument(
+ "enhanced_dir",
+ type=Path,
+ help="Path to enhanced data directory.",
+ )
+ parser.add_argument(
+ "--num-jobs",
+ "-j",
+ type=int,
+ default=1,
+ help="Number of parallel jobs to run.",
+ )
+ parser.add_argument(
+ "--min-segment-duration",
+ "-d",
+ type=float,
+ default=0.0,
+ help="Minimum duration of a segment in seconds.",
+ )
+ return parser.parse_args()
+
+
+def find_recording_and_create_new_supervision(enhanced_dir, supervision):
+ """
+    Given a supervision (corresponding to an original AMI recording), this function finds the
+    enhanced recording corresponding to the supervision, and returns this recording together with
+    a new supervision whose start and end times are adjusted to match the enhanced recording.
+ """
+ file_name = Path(
+ f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac"
+ )
+ save_path = enhanced_dir / f"{supervision.recording_id}" / file_name
+ if save_path.exists():
+ recording = Recording.from_file(save_path)
+ if recording.duration == 0:
+ logging.warning(f"Skipping {save_path} which has duration 0 seconds.")
+ return None
+
+        # The old supervision is relative to the original recording; create a new
+        # supervision relative to the enhanced segment.
+ new_supervision = fastcopy(
+ supervision,
+ recording_id=recording.id,
+ start=0,
+ duration=recording.duration,
+ )
+ return recording, new_supervision
+ else:
+ logging.warning(f"{save_path} does not exist.")
+ return None
+
+
+def main(args):
+ # Get arguments
+ manifests_dir = args.manifests_dir
+ enhanced_dir = args.enhanced_dir
+
+ # Load manifests from cache if they exist (saves time)
+ manifests = read_manifests_if_cached(
+ dataset_parts=["train", "dev", "test"],
+ output_dir=manifests_dir,
+ prefix="ami-sdm",
+ suffix="jsonl.gz",
+ )
+ if not manifests:
+ raise ValueError("AMI SDM manifests not found in {}".format(manifests_dir))
+
+ with ThreadPoolExecutor(args.num_jobs) as ex:
+ for part in ["train", "dev", "test"]:
+ logging.info(f"Processing {part}...")
+ supervisions_orig = manifests[part]["supervisions"].filter(
+ lambda s: s.duration >= args.min_segment_duration
+ )
+ # Remove TS3009d supervisions since they are not present in the enhanced data
+ supervisions_orig = supervisions_orig.filter(
+ lambda s: s.recording_id != "TS3009d"
+ )
+ futures = []
+
+ for supervision in tqdm(
+ supervisions_orig,
+ desc="Distributing tasks",
+ ):
+ futures.append(
+ ex.submit(
+ find_recording_and_create_new_supervision,
+ enhanced_dir,
+ supervision,
+ )
+ )
+
+ recordings = []
+ supervisions = []
+ for future in tqdm(
+ futures,
+ total=len(futures),
+ desc="Processing tasks",
+ ):
+ result = future.result()
+ if result is not None:
+ recording, new_supervision = result
+ recordings.append(recording)
+ supervisions.append(new_supervision)
+
+ # Remove duplicates from the recordings
+ recordings_nodup = {}
+ for recording in recordings:
+ if recording.id not in recordings_nodup:
+ recordings_nodup[recording.id] = recording
+ else:
+ logging.warning("Recording {} is duplicated.".format(recording.id))
+ recordings = RecordingSet.from_recordings(recordings_nodup.values())
+ supervisions = SupervisionSet.from_segments(supervisions)
+
+ recordings, supervisions = fix_manifests(
+ recordings=recordings, supervisions=supervisions
+ )
+
+ logging.info(f"Writing {part} enhanced manifests")
+ recordings.to_file(manifests_dir / f"ami-gss_recordings_{part}.jsonl.gz")
+ supervisions.to_file(
+ manifests_dir / f"ami-gss_supervisions_{part}.jsonl.gz"
+ )
+
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/egs/ami/ASR/local/prepare_ami_gss.sh b/egs/ami/ASR/local/prepare_ami_gss.sh
new file mode 100755
index 000000000..d5422458b
--- /dev/null
+++ b/egs/ami/ASR/local/prepare_ami_gss.sh
@@ -0,0 +1,98 @@
+#!/usr/bin/env bash
+# This script is used to run GSS-based enhancement on AMI data.
+set -euo pipefail
+nj=4
+stage=0
+
+. shared/parse_options.sh || exit 1
+
+if [ $# != 2 ]; then
+ echo "Wrong #arguments ($#, expected 2)"
+ echo "Usage: local/prepare_ami_gss.sh [options] "
+ echo "e.g. local/prepare_ami_gss.sh data/manifests exp/ami_gss"
+ echo "main options (for others, see top of script file)"
+ echo " --nj # number of parallel jobs"
+ echo " --stage # stage to start running from"
+ exit 1;
+fi
+
+DATA_DIR=$1
+EXP_DIR=$2
+
+mkdir -p $EXP_DIR
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 1 ]; then
+ log "Stage 1: Prepare cut sets"
+ for part in train dev test; do
+ lhotse cut simple \
+ -r $DATA_DIR/ami-mdm_recordings_${part}.jsonl.gz \
+ -s $DATA_DIR/ami-mdm_supervisions_${part}.jsonl.gz \
+ $EXP_DIR/cuts_${part}.jsonl.gz
+ done
+fi
+
+if [ $stage -le 2 ]; then
+ log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)"
+ for part in train dev test; do
+ lhotse cut trim-to-supervisions --discard-overlapping \
+ $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz
+ done
+fi
+
+if [ $stage -le 3 ]; then
+ log "Stage 3: Split manifests for multi-GPU processing (optional)"
+ for part in train; do
+ gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \
+ $EXP_DIR/cuts_per_segment_${part}_split$nj
+ done
+fi
+
+if [ $stage -le 4 ]; then
+ log "Stage 4: Enhance train segments using GSS (requires GPU)"
+ # for train, we use smaller context and larger batches to speed-up processing
+ for JOB in $(seq $nj); do
+ gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \
+      $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.${JOB}.jsonl.gz $EXP_DIR/enhanced \
+ --bss-iterations 10 \
+ --context-duration 5.0 \
+ --use-garbage-class \
+ --channels 0,1,2,3,4,5,6,7 \
+ --min-segment-length 0.05 \
+ --max-segment-length 35.0 \
+ --max-batch-duration 60.0 \
+ --num-buckets 3 \
+ --num-workers 2
+ done
+fi
+
+if [ $stage -le 5 ]; then
+ log "Stage 5: Enhance dev/test segments using GSS (using GPU)"
+ # for dev/test, we use larger context and smaller batches to get better quality
+ for part in dev test; do
+ for JOB in $(seq $nj); do
+ gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \
+        $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.${JOB}.jsonl.gz \
+ $EXP_DIR/enhanced \
+ --bss-iterations 10 \
+ --context-duration 15.0 \
+ --use-garbage-class \
+ --channels 0,1,2,3,4,5,6,7 \
+ --min-segment-length 0.05 \
+ --max-segment-length 30.0 \
+ --max-batch-duration 45.0 \
+ --num-buckets 3 \
+ --num-workers 2
+ done
+ done
+fi
+
+if [ $stage -le 6 ]; then
+ log "Stage 6: Prepare manifests for GSS-enhanced data"
+ python local/prepare_ami_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05
+fi
diff --git a/egs/ami/ASR/local/prepare_lang_bpe.py b/egs/ami/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/ami/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/ami/ASR/local/train_bpe_model.py b/egs/ami/ASR/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/ami/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/ami/ASR/prepare.sh b/egs/ami/ASR/prepare.sh
new file mode 100755
index 000000000..fb21a8ec6
--- /dev/null
+++ b/egs/ami/ASR/prepare.sh
@@ -0,0 +1,144 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+use_gss=true # Use GSS-based enhancement with MDM setting
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/amicorpus
+# You can find audio and transcripts in this path.
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+#
+# - $dl_dir/{LDC2004S13,LDC2005S13,LDC2004T19,LDC2005T19}
+# These contain the Fisher English audio and transcripts. We will
+# only use the transcripts as extra LM training data (similar to Kaldi).
+#
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+vocab_size=500
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/amicorpus,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/amicorpus $dl_dir/amicorpus
+ #
+ if [ ! -d $dl_dir/amicorpus ]; then
+ lhotse download ami --mic ihm $dl_dir/amicorpus
+ lhotse download ami --mic mdm $dl_dir/amicorpus
+ fi
+
+ # If you have pre-downloaded it to /path/to/musan,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/musan $dl_dir/
+ #
+ if [ ! -d $dl_dir/musan ]; then
+ lhotse download musan $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare AMI manifests"
+ # We assume that you have downloaded the AMI corpus
+ # to $dl_dir/amicorpus. We perform text normalization for the transcripts.
+ mkdir -p data/manifests
+ for mic in ihm sdm mdm; do
+ lhotse prepare ami --mic $mic --partition full-corpus-asr --normalize-text kaldi \
+ --max-words-per-segment 30 $dl_dir/amicorpus data/manifests/
+ done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to $dl_dir/musan
+ mkdir -p data/manifests
+ lhotse prepare musan $dl_dir/musan data/manifests
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then
+ log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)"
+ # We assume that you have installed the GSS package: https://github.com/desh2608/gss
+ local/prepare_ami_gss.sh data/manifests exp/ami_gss
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank features for AMI"
+ mkdir -p data/fbank
+ python local/compute_fbank_ami.py
+ log "Combine features from train splits"
+ lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
+ gzip -c > data/manifests/cuts_train_all.jsonl.gz
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Compute fbank features for musan"
+ mkdir -p data/fbank
+ python local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Dump transcripts for BPE model training."
+ mkdir -p data/lm
+ cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g')> data/lm/transcript_words.txt
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Prepare BPE based lang"
+
+ lang_dir=data/lang_bpe_${vocab_size}
+ mkdir -p $lang_dir
+
+ # Add special words to words.txt
+ echo " 0" > $lang_dir/words.txt
+ echo "!SIL 1" >> $lang_dir/words.txt
+ echo " 2" >> $lang_dir/words.txt
+
+ # Add regular words to words.txt
+ cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt
+
+ # Add remaining special word symbols expected by LM scripts.
+ num_words=$(cat $lang_dir/words.txt | wc -l)
+ echo " ${num_words}" >> $lang_dir/words.txt
+ num_words=$(cat $lang_dir/words.txt | wc -l)
+ echo " ${num_words}" >> $lang_dir/words.txt
+ num_words=$(cat $lang_dir/words.txt | wc -l)
+ echo "#0 ${num_words}" >> $lang_dir/words.txt
+
+ ./local/train_bpe_model.py \
+ --lang-dir $lang_dir \
+ --vocab-size $vocab_size \
+ --transcript data/lm/transcript_words.txt
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+ fi
+fi
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/__init__.py b/egs/ami/ASR/pruned_transducer_stateless7/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
new file mode 100644
index 000000000..f7ee9c962
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -0,0 +1,430 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import re
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.cut import Cut
+from lhotse.dataset import (
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class AmiAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. the AMI dev and test
+    sets for IHM, SDM, and GSS-enhanced MDM).
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+ This class should be derived for specific corpora used in ASR tasks.
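+
+    A typical usage sketch (mirroring how decode.py in this directory consumes it):
+
+        parser = argparse.ArgumentParser()
+        AmiAsrDataModule.add_arguments(parser)
+        args = parser.parse_args()
+        ami = AmiAsrDataModule(args)
+        dev_ihm_dl = ami.test_dataloaders(ami.dev_ihm_cuts())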
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description=(
+ "These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc."
+ ),
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/manifests"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help=(
+ "When enabled, select noise from MUSAN and mix it "
+ "with training dataset. "
+ ),
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help=(
+ "When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding."
+ ),
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help=(
+ "Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch."
+ ),
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help=(
+ "The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used."
+ ),
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=100.0,
+ help=(
+ "Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM."
+ ),
+ )
+ group.add_argument(
+ "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch."
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=50,
+ help=(
+ "The number of buckets for the BucketingSampler"
+ "(you might want to increase it for larger datasets)."
+ ),
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help=(
+ "When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available."
+ ),
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help=(
+ "When enabled (=default), the examples will be "
+ "shuffled for each epoch."
+ ),
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=8,
+ help=(
+ "The number of training dataloader workers that " "collect the batches."
+ ),
+ )
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help=(
+ "Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp."
+ ),
+ )
+ group.add_argument(
+ "--ihm-only",
+ type=str2bool,
+ default=False,
+ help="When enabled, only use IHM data for training.",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to get Musan cuts")
+
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ "Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=2,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ if self.args.on_the_fly_feats:
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ )
+ else:
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ )
+
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ max_cuts=self.args.max_cuts,
+            shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=True,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=True,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts, max_duration=self.args.max_duration, shuffle=False
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ def remove_short_cuts(self, cut: Cut) -> bool:
+ """
+ See: https://github.com/k2-fsa/icefall/issues/500
+ Basically, the zipformer model subsamples the input using the following formula:
+ num_out_frames = (num_in_frames - 7)//2
+ For num_out_frames to be at least 1, num_in_frames must be at least 9.
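+        At the 10 ms frame shift used for the fbank features, 9 frames correspond to
+        0.09 seconds, hence the duration threshold below.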
+ """
+ return cut.duration >= 0.09
+
+ @lru_cache()
+ def train_cuts(self, sp: Optional[Any] = None) -> CutSet:
+ logging.info("About to get AMI train cuts")
+
+ def _remove_short_and_long_utt(c: Cut):
+ if c.duration < 0.2 or c.duration > 25.0:
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+ return T >= len(tokens)
+
+ if self.args.ihm_only:
+ cuts_train = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_train_ihm.jsonl.gz"
+ )
+ else:
+ cuts_train = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_train_all.jsonl.gz"
+ )
+
+ return cuts_train.filter(_remove_short_and_long_utt)
+
+ @lru_cache()
+ def dev_ihm_cuts(self) -> CutSet:
+ logging.info("About to get AMI IHM dev cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
+
+ @lru_cache()
+ def dev_sdm_cuts(self) -> CutSet:
+ logging.info("About to get AMI SDM dev cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
+
+ @lru_cache()
+ def dev_gss_cuts(self) -> CutSet:
+ if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists():
+ logging.info("No GSS dev cuts found")
+ return None
+ logging.info("About to get AMI GSS-enhanced dev cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_gss.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
+
+ @lru_cache()
+ def test_ihm_cuts(self) -> CutSet:
+ logging.info("About to get AMI IHM test cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
+
+ @lru_cache()
+ def test_sdm_cuts(self) -> CutSet:
+ logging.info("About to get AMI SDM test cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
+
+ @lru_cache()
+ def test_gss_cuts(self) -> CutSet:
+ if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
+ logging.info("No GSS test cuts found")
+ return None
+ logging.info("About to get AMI GSS-enhanced test cuts")
+ cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz")
+ return cs.filter(self.remove_short_cuts)
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py
new file mode 120000
index 000000000..37516affc
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py
new file mode 100755
index 000000000..9999894d1
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py
@@ -0,0 +1,739 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7/decode.py \
+ --iter 105000 \
+ --avg 10 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 100 \
+ --decoding-method greedy_search
+
+(2) beam search
+./pruned_transducer_stateless7/decode.py \
+ --iter 105000 \
+ --avg 10 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7/decode.py \
+ --iter 105000 \
+ --avg 10 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search
+./pruned_transducer_stateless7/decode.py \
+ --iter 105000 \
+ --avg 10 \
+  --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AmiAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall import NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 0.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=10,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An interger indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+ word_table: Optional[k2.SymbolTable] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search.
+ word_table:
+ The word symbol table.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
+ word_table: Optional[k2.SymbolTable] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 100
+ else:
+ log_interval = 2
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ word_table=word_table,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ test_set_cers = dict()
+ for key, results in results_dict.items():
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt"
+ with open(wers_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ # we also compute CER for AMI dataset.
+ results_char = []
+ for res in results:
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+ cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt"
+ with open(cers_filename, "w") as f:
+ cer = write_error_stats(
+ f, f"{test_set_name}-{key}", results_char, enable_log=True
+ )
+ test_set_cers[key] = cer
+
+ logging.info("Wrote detailed error stats to {}".format(wers_filename))
+
+ test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
+ test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER\tCER", file=f)
+ for key in test_set_wers:
+ print(
+ "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]),
+ file=f,
+ )
+
+ s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key in test_set_wers:
+ s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AmiAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest_LG",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(f"{params.lang_dir}/bpe.model")
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ ami = AmiAsrDataModule(args)
+
+ dev_ihm_cuts = ami.dev_ihm_cuts()
+ test_ihm_cuts = ami.test_ihm_cuts()
+ dev_sdm_cuts = ami.dev_sdm_cuts()
+ test_sdm_cuts = ami.test_sdm_cuts()
+ dev_gss_cuts = ami.dev_gss_cuts()
+ test_gss_cuts = ami.test_gss_cuts()
+
+ dev_ihm_dl = ami.test_dataloaders(dev_ihm_cuts)
+ test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
+ dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
+ test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
+ if dev_gss_cuts is not None:
+ dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
+ if test_gss_cuts is not None:
+ test_gss_dl = ami.test_dataloaders(test_gss_cuts)
+
+ test_sets = {
+ "dev_ihm": (dev_ihm_dl, dev_ihm_cuts),
+ "test_ihm": (test_ihm_dl, test_ihm_cuts),
+ "dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
+ "test_sdm": (test_sdm_dl, test_sdm_cuts),
+ }
+ if dev_gss_cuts is not None:
+ test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
+ if test_gss_cuts is not None:
+ test_sets["test_gss"] = (test_gss_dl, test_gss_cuts)
+
+ for test_set in test_sets:
+ logging.info(f"Decoding {test_set}")
+ dl, cuts = test_sets[test_set]
+ results_dict = decode_dataset(
+ dl=dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decoder.py b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py
new file mode 120000
index 000000000..0c2673d46
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/export.py b/egs/ami/ASR/pruned_transducer_stateless7/export.py
new file mode 120000
index 000000000..2713792e6
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/export.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/joiner.py b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/model.py b/egs/ami/ASR/pruned_transducer_stateless7/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/optim.py b/egs/ami/ASR/pruned_transducer_stateless7/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..81823ced2
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1193 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+ --world-size 4 \
+ --num-epochs 15 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless7/exp \
+ --max-duration 150 \
+ --use-fp16 True
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AmiAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,4,3,2,4",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="1024,1024,2048,2048,1024",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="384,384,384,384,384",
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+ not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+ " worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=11,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.05, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=5000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=10,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
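+ For example, with `--average-period 200`, at batch_idx_train = 1000 the
+ update is `model_avg = 0.2 * model + 0.8 * model_avg`.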
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 100,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mix precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"]
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = supervisions["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+
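+ # E.g., halfway through the warm-up (batch_idx_train = warm_step / 2) with
+ # the default --simple-loss-scale 0.5, this gives simple_loss_scale = 0.75
+ # and pruned_loss_scale = 0.55.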
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = ((feature_lens - 7) // 2).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction=sum, so the loss is summed over utterances
+ # in the batch; no per-frame normalization has been applied so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
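+ # parameters_names is passed so that ScaledAdam, when clipping_scale is set,
+ # can report which parameters dominate the gradient norm.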
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ ami = AmiAsrDataModule(args)
+
+ # Here are the duration statistics of the training set.
+ # Cuts count: 1230033
+ # Total duration (hh:mm:ss): 904:25:34
+ # Speech duration (hh:mm:ss): 904:25:34 (100.0%)
+ # Duration statistics (seconds):
+ # mean 2.6
+ # std 2.8
+ # min 0.0
+ # 25% 0.6
+ # 50% 1.6
+ # 75% 3.8
+ # 99% 12.3
+ # 99.5% 13.9
+ # 99.9% 18.3
+ # max 36.8
+
+ train_cuts = ami.train_cuts(sp=sp)
+ train_dl = ami.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict)
+
+ valid_cuts = ami.dev_ihm_cuts()
+ valid_dl = ami.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ AmiAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/ami/ASR/shared b/egs/ami/ASR/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/ami/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/README.md b/egs/commonvoice/ASR/README.md
new file mode 100644
index 000000000..a4582499b
--- /dev/null
+++ b/egs/commonvoice/ASR/README.md
@@ -0,0 +1,18 @@
+# Introduction
+
+This recipe includes ASR models trained on Common Voice.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# Transducers
+
+There are various folders containing the word `transducer` in this directory.
+The following table lists the differences among them.
+
+| | Encoder | Decoder | Comment |
+|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
+| `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan |
+
+The decoder, i.e., the prediction network, is modified from the paper
+[RNN-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer.
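+
+Below is a minimal, self-contained sketch of that idea (illustrative only; the
+sizes are hypothetical and the actual implementation in
+`pruned_transducer_stateless7/decoder.py` uses different layer settings):
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Hypothetical sizes; the recipe's own defaults differ.
+vocab_size, decoder_dim, context_size = 500, 512, 2
+
+embedding = nn.Embedding(vocab_size, decoder_dim)
+conv = nn.Conv1d(decoder_dim, decoder_dim, kernel_size=context_size)
+
+y = torch.randint(0, vocab_size, (2, 5))  # (N, U) previously emitted tokens
+emb = embedding(y).permute(0, 2, 1)  # (N, D, U)
+emb = F.pad(emb, (context_size - 1, 0))  # left padding keeps the context causal
+decoder_out = torch.relu(conv(emb)).permute(0, 2, 1)  # (N, U, D)
+print(decoder_out.shape)  # torch.Size([2, 5, 512])
+```
+
+The prediction network thus only sees the last `context_size` tokens instead of
+the full label history, which is what makes the decoder "stateless".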
diff --git a/egs/commonvoice/ASR/RESULTS.md b/egs/commonvoice/ASR/RESULTS.md
new file mode 100644
index 000000000..751625371
--- /dev/null
+++ b/egs/commonvoice/ASR/RESULTS.md
@@ -0,0 +1,59 @@
+## Results
+### CommonVoice BPE training results (Pruned Stateless Transducer 7)
+
+#### [pruned_transducer_stateless7](./pruned_transducer_stateless7)
+
+See #997 for more details.
+
+Number of model parameters: 70369391, i.e., 70.37 M
+
+The best WER, as of 2023-04-17, for Common Voice English 13.0 (cv-corpus-13.0-2023-03-09/en) is below:
+
+| | Dev | Test |
+|----------------------|-------|-------|
+| greedy search | 9.96 | 12.54 |
+| modified beam search | 9.86 | 12.48 |
+
+To reproduce the above result, use the following commands for training:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./pruned_transducer_stateless7/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7/exp \
+ --max-duration 550
+```
+
+and the following commands for decoding:
+
+```bash
+# greedy search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 30 \
+ --avg 5 \
+ --decoding-method greedy_search \
+ --exp-dir pruned_transducer_stateless7/exp \
+ --bpe-model data/en/lang_bpe_500/bpe.model \
+ --max-duration 600
+
+# modified beam search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 30 \
+ --avg 5 \
+ --decoding-method modified_beam_search \
+ --beam-size 4 \
+ --exp-dir pruned_transducer_stateless7/exp \
+ --bpe-model data/en/lang_bpe_500/bpe.model \
+ --max-duration 600
+```
+
+Pretrained model is available at
+
+
+The tensorboard log for training is available at
+
diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py
new file mode 100755
index 000000000..c8f9b6ccb
--- /dev/null
+++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the CommonVoice dataset.
+It looks for manifests in the directory data/${lang}/manifests.
+
+The generated fbank features are saved in data/${lang}/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from filter_cuts import filter_cuts
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--language",
+ type=str,
+ help="""Language of Common Voice""",
+ )
+
+ return parser.parse_args()
+
+
+def compute_fbank_commonvoice_dev_test(language: str):
+ src_dir = Path(f"data/{language}/manifests")
+ output_dir = Path(f"data/{language}/fbank")
+ num_workers = 42
+ batch_duration = 600
+
+ subsets = ("dev", "test")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+
+ logging.info(f"device: {device}")
+
+ for partition in subsets:
+ cuts_path = output_dir / f"cv-{language}_cuts_{partition}.jsonl.gz"
+ if cuts_path.is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+
+ raw_cuts_path = output_dir / f"cv-{language}_cuts_{partition}_raw.jsonl.gz"
+
+ logging.info(f"Loading {raw_cuts_path}")
+ cut_set = CutSet.from_file(raw_cuts_path)
+
+ logging.info("Splitting cuts into smaller chunks")
+ cut_set = cut_set.trim_to_supervisions(
+ keep_overlapping=False, min_duration=None
+ )
+
+ logging.info("Computing features")
+ cut_set = cut_set.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=f"{output_dir}/cv-{language}_feats_{partition}",
+ num_workers=num_workers,
+ batch_duration=batch_duration,
+ storage_type=LilcomChunkyWriter,
+ overwrite=True,
+ )
+
+ logging.info(f"Saving to {cuts_path}")
+ cut_set.to_file(cuts_path)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ args = get_args()
+ logging.info(vars(args))
+ compute_fbank_commonvoice_dev_test(language=args.language)
diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
new file mode 100755
index 000000000..0564f6ec6
--- /dev/null
+++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from datetime import datetime
+from pathlib import Path
+
+import torch
+from lhotse import (
+ CutSet,
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+ LilcomChunkyWriter,
+ set_audio_duration_mismatch_tolerance,
+ set_caching_enabled,
+)
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--language",
+ type=str,
+ help="""Language of Common Voice""",
+ )
+
+ parser.add_argument(
+ "--num-workers",
+ type=int,
+ default=20,
+ help="Number of dataloading workers used for reading the audio.",
+ )
+
+ parser.add_argument(
+ "--batch-duration",
+ type=float,
+ default=600.0,
+ help="The maximum number of audio seconds in a batch."
+ "Determines batch size dynamically.",
+ )
+
+ parser.add_argument(
+ "--num-splits",
+ type=int,
+ required=True,
+ help="The number of splits of the train subset",
+ )
+
+ parser.add_argument(
+ "--start",
+ type=int,
+ default=0,
+ help="Process pieces starting from this number (inclusive).",
+ )
+
+ parser.add_argument(
+ "--stop",
+ type=int,
+ default=-1,
+ help="Stop processing pieces until this number (exclusive).",
+ )
+
+ return parser.parse_args()
+
+
+def compute_fbank_commonvoice_splits(args):
+ subset = "train"
+ num_splits = args.num_splits
+ language = args.language
+ output_dir = f"data/{language}/fbank/cv-{language}_{subset}_split_{num_splits}"
+ output_dir = Path(output_dir)
+ assert output_dir.exists(), f"{output_dir} does not exist!"
+
+ num_digits = len(str(num_splits))
+
+ start = args.start
+ stop = args.stop
+ if stop < start:
+ stop = num_splits
+
+ stop = min(stop, num_splits)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+ logging.info(f"device: {device}")
+
+ set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
+ set_caching_enabled(False)
+ for i in range(start, stop):
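+ # Pieces are processed using 1-based, zero-padded indices so that the file
+ # names below line up with the split manifests produced in prepare.sh.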
+ idx = f"{i + 1}".zfill(num_digits)
+ logging.info(f"Processing {idx}/{num_splits}")
+
+ cuts_path = output_dir / f"cv-{language}_cuts_{subset}.{idx}.jsonl.gz"
+ if cuts_path.is_file():
+ logging.info(f"{cuts_path} exists - skipping")
+ continue
+
+ raw_cuts_path = output_dir / f"cv-{language}_cuts_{subset}_raw.{idx}.jsonl.gz"
+
+ logging.info(f"Loading {raw_cuts_path}")
+ cut_set = CutSet.from_file(raw_cuts_path)
+
+ logging.info("Splitting cuts into smaller chunks.")
+ cut_set = cut_set.trim_to_supervisions(
+ keep_overlapping=False, min_duration=None
+ )
+
+ logging.info("Computing features")
+ cut_set = cut_set.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=f"{output_dir}/cv-{language}_feats_{subset}_{idx}",
+ num_workers=args.num_workers,
+ batch_duration=args.batch_duration,
+ storage_type=LilcomChunkyWriter,
+ overwrite=True,
+ )
+
+ logging.info(f"Saving to {cuts_path}")
+ cut_set.to_file(cuts_path)
+
+
+def main():
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ args = get_args()
+ logging.info(vars(args))
+ compute_fbank_commonvoice_splits(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/local/compute_fbank_musan.py b/egs/commonvoice/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/commonvoice/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/filter_cuts.py b/egs/commonvoice/ASR/local/filter_cuts.py
new file mode 120000
index 000000000..27aca1729
--- /dev/null
+++ b/egs/commonvoice/ASR/local/filter_cuts.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/filter_cuts.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/prepare_lang_bpe.py b/egs/commonvoice/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/commonvoice/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py
new file mode 100755
index 000000000..c5ec14502
--- /dev/null
+++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import re
+from pathlib import Path
+from typing import Optional
+
+from lhotse import CutSet, SupervisionSegment
+from lhotse.recipes.utils import read_manifests_if_cached
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="""Dataset parts to compute fbank. If None, we will use all""",
+ )
+
+ parser.add_argument(
+ "--language",
+ type=str,
+ help="""Language of Common Voice""",
+ )
+
+ return parser.parse_args()
+
+
+def normalize_text(utt: str) -> str:
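+ # Replace hyphens with spaces, drop everything except ASCII letters and
+ # whitespace, then uppercase, e.g. "It's well-known." -> "ITS WELL KNOWN".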
+ utt = re.sub(r"[{0}]+".format("-"), " ", utt)
+ return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
+
+
+def preprocess_commonvoice(
+ language: str,
+ dataset: Optional[str] = None,
+):
+ src_dir = Path(f"data/{language}/manifests")
+ output_dir = Path(f"data/{language}/fbank")
+ output_dir.mkdir(exist_ok=True)
+
+ if dataset is None:
+ dataset_parts = (
+ "dev",
+ "test",
+ "train",
+ )
+ else:
+ dataset_parts = dataset.split(" ", -1)
+
+ logging.info("Loading manifest")
+ prefix = f"cv-{language}"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ suffix=suffix,
+ prefix=prefix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ for partition, m in manifests.items():
+ logging.info(f"Processing {partition}")
+ raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
+ if raw_cuts_path.is_file():
+ logging.info(f"{partition} already exists - skipping")
+ continue
+
+ logging.info(f"Normalizing text in {partition}")
+ for sup in m["supervisions"]:
+ text = str(sup.text)
+ orig_text = text
+ sup.text = normalize_text(sup.text)
+ text = str(sup.text)
+ if len(orig_text) != len(text):
+ logging.info(
+ f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
+ )
+
+ # Create long-recording cut manifests.
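+ # Common Voice ships audio as MP3; resample to 16 kHz so that feature
+ # extraction is consistent with the other recipes.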
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ ).resample(16000)
+
+ logging.info(f"Saving to {raw_cuts_path}")
+ cut_set.to_file(raw_cuts_path)
+
+
+def main():
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ args = get_args()
+ logging.info(vars(args))
+ preprocess_commonvoice(
+ language=args.language,
+ dataset=args.dataset,
+ )
+ logging.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/local/train_bpe_model.py b/egs/commonvoice/ASR/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/commonvoice/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/validate_bpe_lexicon.py b/egs/commonvoice/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/commonvoice/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh
new file mode 100755
index 000000000..7a583f9c8
--- /dev/null
+++ b/egs/commonvoice/ASR/prepare.sh
@@ -0,0 +1,244 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+nj=16
+stage=-1
+stop_stage=100
+
+# Split the ${lang} train set into this number of pieces.
+# This is to avoid OOM during feature extraction.
+num_splits=1000
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/$release/$lang
+# This directory contains the following files downloaded from
+# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz
+#
+# - clips
+# - dev.tsv
+# - invalidated.tsv
+# - other.tsv
+# - reported.tsv
+# - test.tsv
+# - train.tsv
+# - validated.tsv
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+
+dl_dir=$PWD/download
+release=cv-corpus-13.0-2023-03-09
+lang=en
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/${lang}/lang_bpe_xxx,
+# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+ # 5000
+ # 2000
+ # 1000
+ 500
+)
+
+# All files generated by this script are saved in "data/${lang}".
+# You can safely remove "data/${lang}" and rerun this script to regenerate it.
+mkdir -p data/${lang}
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/$release,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/$release $dl_dir/$release
+ #
+ if [ ! -d $dl_dir/$release/$lang/clips ]; then
+ lhotse download commonvoice --languages $lang --release $release $dl_dir
+ fi
+
+ # If you have pre-downloaded it to /path/to/musan,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/musan $dl_dir/
+ #
+ if [ ! -d $dl_dir/musan ]; then
+ lhotse download musan $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare CommonVoice manifest"
+ # We assume that you have downloaded the CommonVoice corpus
+ # to $dl_dir/$release
+ mkdir -p data/${lang}/manifests
+ if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then
+ lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests
+ touch data/${lang}/manifests/.cv-${lang}.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to $dl_dir/musan
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.musan.done ]; then
+ lhotse prepare musan $dl_dir/musan data/manifests
+ touch data/manifests/.musan.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Preprocess CommonVoice manifest"
+ if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then
+ ./local/preprocess_commonvoice.py --language $lang
+ touch data/${lang}/fbank/.preprocess_complete
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank for dev and test subsets of CommonVoice"
+ mkdir -p data/${lang}/fbank
+ if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then
+ ./local/compute_fbank_commonvoice_dev_test.py --language $lang
+ touch data/${lang}/fbank/.cv-${lang}_dev_test.done
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Split train subset into ${num_splits} pieces"
+ split_dir=data/${lang}/fbank/cv-${lang}_train_split_${num_splits}
+ if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then
+ lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir
+ touch $split_dir/.cv-${lang}_train_split.done
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Compute features for train subset of CommonVoice"
+ if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then
+ ./local/compute_fbank_commonvoice_splits.py \
+ --num-workers $nj \
+ --batch-duration 600 \
+ --start 0 \
+ --num-splits $num_splits \
+ --language $lang
+ touch data/${lang}/fbank/.cv-${lang}_train.done
+ fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Combine features for train"
+ if [ ! -f data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz ]; then
+ pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz")
+ lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz
+ fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ log "Stage 8: Compute fbank for musan"
+ mkdir -p data/fbank
+ if [ ! -e data/fbank/.musan.done ]; then
+ ./local/compute_fbank_musan.py
+ touch data/fbank/.musan.done
+ fi
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+ log "Stage 9: Prepare BPE based lang"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/${lang}/lang_bpe_${vocab_size}
+ mkdir -p $lang_dir
+
+ if [ ! -f $lang_dir/transcript_words.txt ]; then
+ log "Generate data for BPE training"
+ file=$(
+ find "data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz"
+ )
+ gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt
+
+ # Ensure space only appears once
+ sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
+ sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
+ fi
+
+ if [ ! -f $lang_dir/words.txt ]; then
+ cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
+ | sort -u | sed '/^$/d' > $lang_dir/words.txt
+ (echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
+ cat - $lang_dir/words.txt | sort | uniq | awk '
+ BEGIN {
+ print "<eps> 0";
+ }
+ {
+ if ($1 == "<s>") {
+ print "<s> is in the vocabulary!" | "cat 1>&2"
+ exit 1;
+ }
+ if ($1 == "</s>") {
+ print "</s> is in the vocabulary!" | "cat 1>&2"
+ exit 1;
+ }
+ printf("%s %d\n", $1, NR);
+ }
+ END {
+ printf("#0 %d\n", NR+1);
+ printf("<s> %d\n", NR+2);
+ printf("</s> %d\n", NR+3);
+ }' > $lang_dir/words || exit 1;
+ mv $lang_dir/words $lang_dir/words.txt
+ fi
+
+ if [ ! -f $lang_dir/bpe.model ]; then
+ ./local/train_bpe_model.py \
+ --lang-dir $lang_dir \
+ --vocab-size $vocab_size \
+ --transcript $lang_dir/transcript_words.txt
+ fi
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+ log "Validating $lang_dir/lexicon.txt"
+ ./local/validate_bpe_lexicon.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --bpe-model $lang_dir/bpe.model
+ fi
+
+ if [ ! -f $lang_dir/L.fst ]; then
+ log "Converting L.pt to L.fst"
+ ./shared/convert-k2-to-openfst.py \
+ --olabels aux_labels \
+ $lang_dir/L.pt \
+ $lang_dir/L.fst
+ fi
+
+ if [ ! -f $lang_dir/L_disambig.fst ]; then
+ log "Converting L_disambig.pt to L_disambig.fst"
+ ./shared/convert-k2-to-openfst.py \
+ --olabels aux_labels \
+ $lang_dir/L_disambig.pt \
+ $lang_dir/L_disambig.fst
+ fi
+ done
+fi
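Note on Stage 9: the transcript extraction above relies on `awk -F '"' '{print $30}'`, i.e. on the supervision text landing in a fixed quoted-field position inside the JSONL cuts manifest, which silently breaks if the field order changes. Below is a minimal Python sketch of the same step that parses the JSON instead; the paths assume lang=en and the lang_bpe_500 directory used in this recipe and are illustrative only, not part of prepare.sh.

#!/usr/bin/env python3
# Sketch: extract supervision texts from a Lhotse cuts manifest by parsing
# the JSON lines instead of relying on a fixed quoted-field position.
# The paths below mirror prepare.sh for lang=en and are assumptions.
import gzip
import json

cuts_path = "data/en/fbank/cv-en_cuts_train.jsonl.gz"
out_path = "data/en/lang_bpe_500/transcript_words.txt"

with gzip.open(cuts_path, "rt", encoding="utf-8") as fin, open(
    out_path, "w", encoding="utf-8"
) as fout:
    for line in fin:
        cut = json.loads(line)
        for sup in cut.get("supervisions", []):
            # Collapse tabs and repeated spaces, as the sed commands above do.
            text = " ".join(sup["text"].split())
            if text:
                fout.write(text + "\n")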
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/__init__.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
new file mode 100644
index 000000000..2c37244a4
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -0,0 +1,420 @@
+# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SingleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class CommonVoiceAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. CommonVoice dev
+ and test).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--language",
+ type=str,
+ default="en",
+ help="""Language of Common Voice""",
+ )
+ group.add_argument(
+ "--cv-manifest-dir",
+ type=Path,
+ default=Path("data/en/fbank"),
+ help="Path to directory with CommonVoice train/dev/test cuts.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with the other cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SingleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ return load_manifest_lazy(
+ self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ return load_manifest_lazy(
+ self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz"
+ )
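A minimal usage sketch for the data module above, assuming the manifests produced by prepare.sh exist under data/en/fbank (and data/fbank/musan_cuts.jsonl.gz when --enable-musan keeps its default). It only illustrates how a training or decoding script is expected to wire the CLI options into dataloaders; the argument values are examples, not required settings.

# Sketch: build CommonVoice dataloaders from CommonVoiceAsrDataModule.
# Run from inside egs/commonvoice/ASR/pruned_transducer_stateless7 so that
# asr_datamodule.py is importable; the paths below are assumptions.
import argparse

from asr_datamodule import CommonVoiceAsrDataModule

parser = argparse.ArgumentParser()
CommonVoiceAsrDataModule.add_arguments(parser)
args = parser.parse_args(
    ["--language", "en", "--cv-manifest-dir", "data/en/fbank", "--max-duration", "200"]
)

commonvoice = CommonVoiceAsrDataModule(args)
train_cuts = commonvoice.train_cuts()
train_dl = commonvoice.train_dataloaders(train_cuts)

for batch in train_dl:
    feats = batch["inputs"]                        # (N, T, C) fbank features
    num_frames = batch["supervisions"]["num_frames"]
    print(feats.shape, num_frames[:4])
    break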
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py
new file mode 100755
index 000000000..52b2fbcab
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py
@@ -0,0 +1,962 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless7/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./pruned_transducer_stateless7/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless7/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless7/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless7/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(8) modified beam search with RNNLM shallow fusion
+./pruned_transducer_stateless7/decode.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search_lm_shallow_fusion \
+ --beam-size 4 \
+ --lm-type rnn \
+ --lm-scale 0.3 \
+ --lm-exp-dir /path/to/LM \
+ --rnn-lm-epoch 99 \
+ --rnn-lm-avg 1 \
+ --rnn-lm-num-layers 3 \
+ --rnn-lm-tie-weights 1
+
+(9) modified beam search with LM shallow fusion + LODR
+./pruned_transducer_stateless7/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --max-duration 600 \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --decoding-method modified_beam_search_LODR \
+ --beam-size 4 \
+ --lm-type rnn \
+ --lm-scale 0.4 \
+ --lm-exp-dir /path/to/LM \
+ --rnn-lm-epoch 99 \
+ --rnn-lm-avg 1 \
+ --rnn-lm-num-layers 3 \
+ --rnn-lm-tie-weights 1 \
+ --tokens-ngram 2 \
+ --ngram-lm-scale -0.16
+
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import CommonVoiceAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+ modified_beam_search_ngram_rescoring,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall import LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/en/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/en/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ - modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion
+ - modified_beam_search_LODR
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=3,
+ help="""Token Ngram used for rescoring.
+ Used only when the decoding method is
+ modified_beam_search_ngram_rescoring, or LODR
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="""ID of the backoff symbol.
+ Used only when the decoding method is
+ modified_beam_search_ngram_rescoring""",
+ )
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ ngram_lm: Optional[NgramLm] = None,
+ ngram_lm_scale: float = 1.0,
+ LM: Optional[LmScorer] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search".
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7".
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural net LM for shallow fusion. Only used when `--use-shallow-fusion`
+ is set to true.
+ ngram_lm:
+ An n-gram LM. Used in LODR decoding.
+ ngram_lm_scale:
+ The scale of the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ ngram_lm: Optional[NgramLm] = None,
+ ngram_lm_scale: float = 1.0,
+ LM: Optional[LmScorer] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network LM, used during shallow fusion
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ word_table=word_table,
+ batch=batch,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ LM=LM,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if "ngram" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ if params.use_shallow_fusion:
+ if params.lm_type == "rnn":
+ params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
+ elif params.lm_type == "transformer":
+ params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load N-gram LM when needed
+ if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"lm filename: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ # only load the neural network LM if doing shallow fusion
+ if params.use_shallow_fusion:
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+
+ else:
+ LM = None
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ dev_cuts = commonvoice.dev_cuts()
+ test_cuts = commonvoice.test_cuts()
+
+ dev_dl = commonvoice.valid_dataloaders(dev_cuts)
+ test_dl = commonvoice.test_dataloaders(test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dl = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ LM=LM,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
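For reference, decode_dataset() above returns a dict keyed by the decoding setting, whose values are lists of (cut_id, ref_words, hyp_words) tuples; save_results() then scores them with icefall.utils.write_error_stats. The sketch below only illustrates that tuple layout with a plain word-level edit distance; it is not the scorer used by the recipe, and the sample data is made up.

# Sketch: quick WER over the (cut_id, ref_words, hyp_words) tuples returned
# by decode_dataset(), using a plain Levenshtein distance for illustration.
from typing import List


def edit_distance(ref: List[str], hyp: List[str]) -> int:
    # Standard dynamic-programming edit distance over word sequences.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, 1):
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (r != h))
        prev = cur
    return prev[-1]


def quick_wer(results) -> float:
    errors = sum(edit_distance(ref, hyp) for _, ref, hyp in results)
    num_words = sum(len(ref) for _, ref, _ in results)
    return 100.0 * errors / max(num_words, 1)


# Example with the tuple shape produced by decode_dataset():
results = [("cut-0", "hello common voice".split(), "hello com voice".split())]
print(f"{quick_wer(results):.2f}%")  # 33.33%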
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py
new file mode 120000
index 000000000..653c5b09a
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py
new file mode 100755
index 000000000..0c98885ac
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/export-onnx.py
@@ -0,0 +1,600 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Yifan Yang)
+
+"""
+This script exports a transducer model from PyTorch to ONNX.
+
+We use the pre-trained model from
+https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/commonvoice/ASR
+
+repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-9999.pt
+popd
+
+2. Export the model to ONNX
+
+./pruned_transducer_stateless7/export-onnx.py \
+ --bpe-model $repo/data/en/lang_bpe_500/bpe.model \
+ --use-averaged-model 0 \
+ --epoch 9999 \
+ --avg 1 \
+ --exp-dir $repo/exp
+
+It will generate the following 3 files inside $repo/exp:
+
+ - encoder-epoch-9999-avg-1.onnx
+ - decoder-epoch-9999-avg-1.onnx
+ - joiner-epoch-9999-avg-1.onnx
+
+See ./onnx_pretrained.py and ./onnx_check.py for how to
+use the exported ONNX models.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, Tuple
+
+import onnx
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from decoder import Decoder
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+from zipformer import Zipformer
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import setup_logger, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for averaging.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/en/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+ """Add meta data to an ONNX model. It is changed in-place.
+
+ Args:
+ filename:
+ Filename of the ONNX model to be changed.
+ meta_data:
+ Key-value pairs.
+ """
+ model = onnx.load(filename)
+ for key, value in meta_data.items():
+ meta = model.metadata_props.add()
+ meta.key = key
+ meta.value = value
+
+ onnx.save(model, filename)
+
+
+class OnnxEncoder(nn.Module):
+ """A wrapper for Zipformer and the encoder_proj from the joiner"""
+
+ def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear):
+ """
+ Args:
+ encoder:
+ A Zipformer encoder.
+ encoder_proj:
+ The projection layer for encoder from the joiner.
+ """
+ super().__init__()
+ self.encoder = encoder
+ self.encoder_proj = encoder_proj
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Please see the help information of Zipformer.forward
+
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C)
+ x_lens:
+ A 1-D tensor of shape (N,). Its dtype is torch.int64
+ Returns:
+ Return a tuple containing:
+ - encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
+ - encoder_out_lens, A 1-D tensor of shape (N,)
+ """
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens)
+
+ encoder_out = self.encoder_proj(encoder_out)
+ # Now encoder_out is of shape (N, T, joiner_dim)
+
+ return encoder_out, encoder_out_lens
+
+
+class OnnxDecoder(nn.Module):
+ """A wrapper for Decoder and the decoder_proj from the joiner"""
+
+ def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
+ super().__init__()
+ self.decoder = decoder
+ self.decoder_proj = decoder_proj
+
+ def forward(self, y: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ y:
+ A 2-D tensor of shape (N, context_size).
+ Returns:
+ Return a 2-D tensor of shape (N, joiner_dim)
+ """
+ need_pad = False
+ decoder_output = self.decoder(y, need_pad=need_pad)
+ decoder_output = decoder_output.squeeze(1)
+ output = self.decoder_proj(decoder_output)
+
+ return output
+
+
+class OnnxJoiner(nn.Module):
+ """A wrapper for the joiner"""
+
+ def __init__(self, output_linear: nn.Linear):
+ super().__init__()
+ self.output_linear = output_linear
+
+ def forward(
+ self,
+ encoder_out: torch.Tensor,
+ decoder_out: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ encoder_out:
+ A 2-D tensor of shape (N, joiner_dim)
+ decoder_out:
+ A 2-D tensor of shape (N, joiner_dim)
+ Returns:
+ Return a 2-D tensor of shape (N, vocab_size)
+ """
+ logit = encoder_out + decoder_out
+ logit = self.output_linear(torch.tanh(logit))
+ return logit
+
+
+def export_encoder_model_onnx(
+ encoder_model: OnnxEncoder,
+ encoder_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the given encoder model to ONNX format.
+ The exported model has two inputs:
+
+ - x, a tensor of shape (N, T, C); dtype is torch.float32
+ - x_lens, a tensor of shape (N,); dtype is torch.int64
+
+ and it has two outputs:
+
+ - encoder_out, a tensor of shape (N, T', joiner_dim)
+ - encoder_out_lens, a tensor of shape (N,)
+
+ Args:
+ encoder_model:
+ The input encoder model
+ encoder_filename:
+ The filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ x = torch.zeros(1, 100, 80, dtype=torch.float32)
+ x_lens = torch.tensor([100], dtype=torch.int64)
+
+ torch.onnx.export(
+ encoder_model,
+ (x, x_lens),
+ encoder_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["x", "x_lens"],
+ output_names=["encoder_out", "encoder_out_lens"],
+ dynamic_axes={
+ "x": {0: "N", 1: "T"},
+ "x_lens": {0: "N"},
+ "encoder_out": {0: "N", 1: "T"},
+ "encoder_out_lens": {0: "N"},
+ },
+ )
+
+ meta_data = {
+ "model_type": "zipformer",
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "stateless7",
+ }
+ logging.info(f"meta_data: {meta_data}")
+
+ add_meta_data(filename=encoder_filename, meta_data=meta_data)
+
+
+def export_decoder_model_onnx(
+ decoder_model: OnnxDecoder,
+ decoder_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the decoder model to ONNX format.
+
+ The exported model has one input:
+
+ - y: a torch.int64 tensor of shape (N, decoder_model.context_size)
+
+ and has one output:
+
+ - decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
+
+ Args:
+ decoder_model:
+ The decoder model to be exported.
+ decoder_filename:
+ Filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ context_size = decoder_model.decoder.context_size
+ vocab_size = decoder_model.decoder.vocab_size
+
+ y = torch.zeros(10, context_size, dtype=torch.int64)
+ torch.onnx.export(
+ decoder_model,
+ y,
+ decoder_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["y"],
+ output_names=["decoder_out"],
+ dynamic_axes={
+ "y": {0: "N"},
+ "decoder_out": {0: "N"},
+ },
+ )
+
+ meta_data = {
+ "context_size": str(context_size),
+ "vocab_size": str(vocab_size),
+ }
+ add_meta_data(filename=decoder_filename, meta_data=meta_data)
+
+
+def export_joiner_model_onnx(
+ joiner_model: nn.Module,
+ joiner_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the joiner model to ONNX format.
+ The exported joiner model has two inputs:
+
+ - encoder_out: a tensor of shape (N, joiner_dim)
+ - decoder_out: a tensor of shape (N, joiner_dim)
+
+ and produces one output:
+
+ - logit: a tensor of shape (N, vocab_size)
+ """
+ joiner_dim = joiner_model.output_linear.weight.shape[1]
+ logging.info(f"joiner dim: {joiner_dim}")
+
+ projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+ projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+
+ torch.onnx.export(
+ joiner_model,
+ (projected_encoder_out, projected_decoder_out),
+ joiner_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=[
+ "encoder_out",
+ "decoder_out",
+ ],
+ output_names=["logit"],
+ dynamic_axes={
+ "encoder_out": {0: "N"},
+ "decoder_out": {0: "N"},
+ "logit": {0: "N"},
+ },
+ )
+ meta_data = {
+ "joiner_dim": str(joiner_dim),
+ }
+ add_meta_data(filename=joiner_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ setup_logger(f"{params.exp_dir}/log-export/log-export-onnx")
+
+ logging.info(f"device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ model.to(device)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ convert_scaled_to_non_scaled(model, inplace=True)
+
+ encoder = OnnxEncoder(
+ encoder=model.encoder,
+ encoder_proj=model.joiner.encoder_proj,
+ )
+
+ decoder = OnnxDecoder(
+ decoder=model.decoder,
+ decoder_proj=model.joiner.decoder_proj,
+ )
+
+ joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
+
+ encoder_num_param = sum([p.numel() for p in encoder.parameters()])
+ decoder_num_param = sum([p.numel() for p in decoder.parameters()])
+ joiner_num_param = sum([p.numel() for p in joiner.parameters()])
+ total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
+ logging.info(f"encoder parameters: {encoder_num_param}")
+ logging.info(f"decoder parameters: {decoder_num_param}")
+ logging.info(f"joiner parameters: {joiner_num_param}")
+ logging.info(f"total parameters: {total_num_param}")
+
+ if params.iter > 0:
+ suffix = f"iter-{params.iter}"
+ else:
+ suffix = f"epoch-{params.epoch}"
+
+ suffix += f"-avg-{params.avg}"
+
+ opset_version = 13
+
+ logging.info("Exporting encoder")
+ encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
+ export_encoder_model_onnx(
+ encoder,
+ encoder_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported encoder to {encoder_filename}")
+
+ logging.info("Exporting decoder")
+ decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
+ export_decoder_model_onnx(
+ decoder,
+ decoder_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported decoder to {decoder_filename}")
+
+ logging.info("Exporting joiner")
+ joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
+ export_joiner_model_onnx(
+ joiner,
+ joiner_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported joiner to {joiner_filename}")
+
+ # Generate int8 quantization models
+ # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+ logging.info("Generate int8 quantization models")
+
+ encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=encoder_filename,
+ model_output=encoder_filename_int8,
+ op_types_to_quantize=["MatMul"],
+ weight_type=QuantType.QInt8,
+ )
+
+ decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=decoder_filename,
+ model_output=decoder_filename_int8,
+ op_types_to_quantize=["MatMul"],
+ weight_type=QuantType.QInt8,
+ )
+
+ joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=joiner_filename,
+ model_output=joiner_filename_int8,
+ op_types_to_quantize=["MatMul"],
+ weight_type=QuantType.QInt8,
+ )
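+ # Dynamic int8 quantization does not change the models' input/output names or
+ # shapes, so the *.int8.onnx files can be loaded by the same code as the
+ # float32 ones (e.g., onnx_pretrained.py). A rough sketch for comparing the
+ # resulting file sizes:
+ #
+ #   import os
+ #   for f in (encoder_filename, encoder_filename_int8):
+ #       logging.info(f"{f}: {os.path.getsize(f) / 2**20:.1f} MB")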
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py
new file mode 100755
index 000000000..53705321e
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/export.py
@@ -0,0 +1,321 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./pruned_transducer_stateless7/export.py \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --bpe-model data/en/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 5 \
+ --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note: `cpu` in the name `cpu_jit.pt` means the parameters are on CPU when the
+model is loaded into Python. You can use `to("cuda")` to move them to a CUDA device.
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./pruned_transducer_stateless7/export.py \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --bpe-model data/en/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 5
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `pruned_transducer_stateless7/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/commonvoice/ASR
+ ./pruned_transducer_stateless7/decode.py \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --decoding-method greedy_search \
+ --bpe-model data/en/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17
+
+with the following commands:
+
+ sudo apt-get install git-lfs
+ git lfs install
+ git clone https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17
+ # You will find the pre-trained model in icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' or '--iter'.",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/en/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ It will generate a file named cpu_jit.pt
+
+ Check ./jit_pretrained.py for how to use it.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ model.to(device)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit is True:
+ convert_scaled_to_non_scaled(model, inplace=True)
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+ # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torchscript. Export model.state_dict()")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py
new file mode 120000
index 000000000..0d8bc665b
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py
new file mode 100755
index 000000000..19c518eaf
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py
@@ -0,0 +1,240 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script checks that exported onnx models produce the same output
+with the given torchscript model for the same input.
+
+We use the pre-trained model from
+https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-9999.pt
+popd
+
+2. Export the model via torchscript (torch.jit.script())
+
+./pruned_transducer_stateless7/export.py \
+ --bpe-model $repo/data/en/lang_bpe_500/bpe.model \
+ --epoch 9999 \
+ --avg 1 \
+ --exp-dir $repo/exp/ \
+ --jit 1
+
+It will generate the following file in $repo/exp:
+ - cpu_jit.pt
+
+3. Export the model to ONNX
+
+./pruned_transducer_stateless7/export-onnx.py \
+ --bpe-model $repo/data/en/lang_bpe_500/bpe.model \
+ --epoch 9999 \
+ --avg 1 \
+ --exp-dir $repo/exp/
+
+It will generate the following 3 files inside $repo/exp:
+
+ - encoder-epoch-9999-avg-1.onnx
+ - decoder-epoch-9999-avg-1.onnx
+ - joiner-epoch-9999-avg-1.onnx
+
+4. Run this file
+
+./pruned_transducer_stateless7/onnx_check.py \
+ --jit-filename $repo/exp/cpu_jit.pt \
+ --onnx-encoder-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
+ --onnx-decoder-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
+ --onnx-joiner-filename $repo/exp/joiner-epoch-9999-avg-1.onnx
+"""
+
+import argparse
+import logging
+
+from icefall import is_module_available
+from onnx_pretrained import OnnxModel
+
+import torch
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--jit-filename",
+ required=True,
+ type=str,
+ help="Path to the torchscript model",
+ )
+
+ parser.add_argument(
+ "--onnx-encoder-filename",
+ required=True,
+ type=str,
+ help="Path to the onnx encoder model",
+ )
+
+ parser.add_argument(
+ "--onnx-decoder-filename",
+ required=True,
+ type=str,
+ help="Path to the onnx decoder model",
+ )
+
+ parser.add_argument(
+ "--onnx-joiner-filename",
+ required=True,
+ type=str,
+ help="Path to the onnx joiner model",
+ )
+
+ return parser
+
+
+def test_encoder(
+ torch_model: torch.jit.ScriptModule,
+ onnx_model: OnnxModel,
+):
+ C = 80
+ for i in range(3):
+ N = torch.randint(low=1, high=20, size=(1,)).item()
+ T = torch.randint(low=30, high=50, size=(1,)).item()
+ logging.info(f"test_encoder: iter {i}, N={N}, T={T}")
+
+ x = torch.rand(N, T, C)
+ x_lens = torch.randint(low=30, high=T + 1, size=(N,))
+ x_lens[0] = T
+
+ torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens)
+ torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out)
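+ # The ONNX encoder exported by export-onnx.py already contains the joiner's
+ # encoder_proj (see OnnxEncoder there), so the same projection is applied to
+ # the torchscript output to make the two directly comparable.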
+
+ onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens)
+
+ assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), (
+ (torch_encoder_out - onnx_encoder_out).abs().max()
+ )
+
+
+def test_decoder(
+ torch_model: torch.jit.ScriptModule,
+ onnx_model: OnnxModel,
+):
+ context_size = onnx_model.context_size
+ vocab_size = onnx_model.vocab_size
+ for i in range(10):
+ N = torch.randint(1, 100, size=(1,)).item()
+ logging.info(f"test_decoder: iter {i}, N={N}")
+ x = torch.randint(
+ low=1,
+ high=vocab_size,
+ size=(N, context_size),
+ dtype=torch.int64,
+ )
+ torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False]))
+ torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out)
+ torch_decoder_out = torch_decoder_out.squeeze(1)
+
+ onnx_decoder_out = onnx_model.run_decoder(x)
+ assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
+ (torch_decoder_out - onnx_decoder_out).abs().max()
+ )
+
+
+def test_joiner(
+ torch_model: torch.jit.ScriptModule,
+ onnx_model: OnnxModel,
+):
+ encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1]
+ decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1]
+ for i in range(10):
+ N = torch.randint(1, 100, size=(1,)).item()
+ logging.info(f"test_joiner: iter {i}, N={N}")
+ encoder_out = torch.rand(N, encoder_dim)
+ decoder_out = torch.rand(N, decoder_dim)
+
+ projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out)
+ projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out)
+
+ torch_joiner_out = torch_model.joiner(encoder_out, decoder_out)
+ onnx_joiner_out = onnx_model.run_joiner(
+ projected_encoder_out, projected_decoder_out
+ )
+
+ assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
+ (torch_joiner_out - onnx_joiner_out).abs().max()
+ )
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ logging.info(vars(args))
+
+ torch_model = torch.jit.load(args.jit_filename)
+
+ onnx_model = OnnxModel(
+ encoder_model_filename=args.onnx_encoder_filename,
+ decoder_model_filename=args.onnx_decoder_filename,
+ joiner_model_filename=args.onnx_joiner_filename,
+ )
+
+ logging.info("Test encoder")
+ test_encoder(torch_model, onnx_model)
+
+ logging.info("Test decoder")
+ test_decoder(torch_model, onnx_model)
+
+ logging.info("Test joiner")
+ test_joiner(torch_model, onnx_model)
+ logging.info("Finished checking ONNX models")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+# See https://github.com/pytorch/pytorch/issues/38342
+# and https://github.com/pytorch/pytorch/issues/33354
+#
+# If we don't do this, the delay increases whenever there is
+# a new request that changes the actual batch size.
+# If you use `py-spy dump --pid --native`, you will
+# see a lot of time is spent in re-compiling the torch script model.
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+torch._C._set_graph_executor_optimize(False)
+if __name__ == "__main__":
+ torch.manual_seed(20220727)
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py
new file mode 100755
index 000000000..eee19191e
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py
@@ -0,0 +1,419 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads ONNX models and uses them to decode waves.
+You can use the following command to get the exported models:
+
+We use the pre-trained model from
+https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/yfyeung/icefall-asr-cv-corpus-13.0-2023-03-09-en-pruned-transducer-stateless7-2023-04-17
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-9999.pt
+popd
+
+2. Export the model to ONNX
+
+./pruned_transducer_stateless7/export-onnx.py \
+ --bpe-model $repo/data/en/lang_bpe_500/bpe.model \
+ --epoch 9999 \
+ --avg 1 \
+ --exp-dir $repo/exp/
+
+It will generate the following 3 files inside $repo/exp:
+
+ - encoder-epoch-9999-avg-1.onnx
+ - decoder-epoch-9999-avg-1.onnx
+ - joiner-epoch-9999-avg-1.onnx
+
+3. Run this file
+
+./pruned_transducer_stateless7/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-9999-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-9999-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-9999-avg-1.onnx \
+ --tokens $repo/data/en/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List, Tuple
+
+import k2
+import kaldifeat
+import numpy as np
+import onnxruntime as ort
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--encoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the encoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--decoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the decoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--joiner-model-filename",
+ type=str,
+ required=True,
+ help="Path to the joiner onnx model. ",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ help="""Path to tokens.txt.""",
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(
+ self,
+ encoder_model_filename: str,
+ decoder_model_filename: str,
+ joiner_model_filename: str,
+ ):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 1
+
+ self.session_opts = session_opts
+
+ self.init_encoder(encoder_model_filename)
+ self.init_decoder(decoder_model_filename)
+ self.init_joiner(joiner_model_filename)
+
+ def init_encoder(self, encoder_model_filename: str):
+ self.encoder = ort.InferenceSession(
+ encoder_model_filename,
+ sess_options=self.session_opts,
+ )
+
+ def init_decoder(self, decoder_model_filename: str):
+ self.decoder = ort.InferenceSession(
+ decoder_model_filename,
+ sess_options=self.session_opts,
+ )
+
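+ # context_size and vocab_size were attached to the decoder as custom ONNX
+ # metadata by export-onnx.py (see add_meta_data there), so they do not have
+ # to be supplied on the command line.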
+ decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
+ self.context_size = int(decoder_meta["context_size"])
+ self.vocab_size = int(decoder_meta["vocab_size"])
+
+ logging.info(f"context_size: {self.context_size}")
+ logging.info(f"vocab_size: {self.vocab_size}")
+
+ def init_joiner(self, joiner_model_filename: str):
+ self.joiner = ort.InferenceSession(
+ joiner_model_filename,
+ sess_options=self.session_opts,
+ )
+
+ joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
+ self.joiner_dim = int(joiner_meta["joiner_dim"])
+
+ logging.info(f"joiner_dim: {self.joiner_dim}")
+
+ def run_encoder(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C)
+ x_lens:
+ A 1-D tensor of shape (N,). Its dtype is torch.int64
+ Returns:
+ Return a tuple containing:
+ - encoder_out, its shape is (N, T', joiner_dim)
+ - encoder_out_lens, its shape is (N,)
+ """
+ out = self.encoder.run(
+ [
+ self.encoder.get_outputs()[0].name,
+ self.encoder.get_outputs()[1].name,
+ ],
+ {
+ self.encoder.get_inputs()[0].name: x.numpy(),
+ self.encoder.get_inputs()[1].name: x_lens.numpy(),
+ },
+ )
+ return torch.from_numpy(out[0]), torch.from_numpy(out[1])
+
+ def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ decoder_input:
+ A 2-D tensor of shape (N, context_size)
+ Returns:
+ Return a 2-D tensor of shape (N, joiner_dim)
+ """
+ out = self.decoder.run(
+ [self.decoder.get_outputs()[0].name],
+ {self.decoder.get_inputs()[0].name: decoder_input.numpy()},
+ )[0]
+
+ return torch.from_numpy(out)
+
+ def run_joiner(
+ self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Args:
+ encoder_out:
+ A 2-D tensor of shape (N, joiner_dim)
+ decoder_out:
+ A 2-D tensor of shape (N, joiner_dim)
+ Returns:
+ Return a 2-D tensor of shape (N, vocab_size)
+ """
+ out = self.joiner.run(
+ [self.joiner.get_outputs()[0].name],
+ {
+ self.joiner.get_inputs()[0].name: encoder_out.numpy(),
+ self.joiner.get_inputs()[1].name: decoder_out.numpy(),
+ },
+ )[0]
+
+ return torch.from_numpy(out)
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list of 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+def greedy_search(
+ model: OnnxModel,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+) -> List[List[int]]:
+ """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+ Args:
+ model:
+ The transducer model.
+ encoder_out:
+ A 3-D tensor of shape (N, T, joiner_dim)
+ encoder_out_lens:
+ A 1-D tensor of shape (N,).
+ Returns:
+ Return the decoded results for each utterance.
+ """
+ assert encoder_out.ndim == 3, encoder_out.shape
+ assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+ packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+ input=encoder_out,
+ lengths=encoder_out_lens.cpu(),
+ batch_first=True,
+ enforce_sorted=False,
+ )
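+ # pack_padded_sequence sorts the utterances by length (longest first) and lays
+ # the frames out time-major: batch_sizes[t] is the number of utterances still
+ # active at frame t. The loop below walks over frames and keeps only the
+ # active hypotheses at each step; unsorted_indices restores the original
+ # utterance order at the end.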
+
+ blank_id = 0 # hard-code to 0
+
+ batch_size_list = packed_encoder_out.batch_sizes.tolist()
+ N = encoder_out.size(0)
+
+ assert torch.all(encoder_out_lens > 0), encoder_out_lens
+ assert N == batch_size_list[0], (N, batch_size_list)
+
+ context_size = model.context_size
+ hyps = [[blank_id] * context_size for _ in range(N)]
+
+ decoder_input = torch.tensor(
+ hyps,
+ dtype=torch.int64,
+ ) # (N, context_size)
+
+ decoder_out = model.run_decoder(decoder_input)
+
+ offset = 0
+ for batch_size in batch_size_list:
+ start = offset
+ end = offset + batch_size
+ current_encoder_out = packed_encoder_out.data[start:end]
+ # current_encoder_out's shape: (batch_size, joiner_dim)
+ offset = end
+
+ decoder_out = decoder_out[:batch_size]
+ logits = model.run_joiner(current_encoder_out, decoder_out)
+
+ # logits' shape: (batch_size, vocab_size)
+
+ assert logits.ndim == 2, logits.shape
+ y = logits.argmax(dim=1).tolist()
+ emitted = False
+ for i, v in enumerate(y):
+ if v != blank_id:
+ hyps[i].append(v)
+ emitted = True
+ if emitted:
+ # update decoder output
+ decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
+ decoder_input = torch.tensor(
+ decoder_input,
+ dtype=torch.int64,
+ )
+ decoder_out = model.run_decoder(decoder_input)
+
+ sorted_ans = [h[context_size:] for h in hyps]
+ ans = []
+ unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+ for i in range(N):
+ ans.append(sorted_ans[unsorted_indices[i]])
+
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+ model = OnnxModel(
+ encoder_model_filename=args.encoder_model_filename,
+ decoder_model_filename=args.decoder_model_filename,
+ joiner_model_filename=args.joiner_model_filename,
+ )
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = "cpu"
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = args.sample_rate
+ opts.mel_opts.num_bins = 80
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {args.sound_files}")
+ waves = read_sound_files(
+ filenames=args.sound_files,
+ expected_sample_rate=args.sample_rate,
+ )
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(
+ features,
+ batch_first=True,
+ padding_value=math.log(1e-10),
+ )
+
+ feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
+ encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths)
+
+ hyps = greedy_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ s = "\n"
+
+ symbol_table = k2.SymbolTable.from_file(args.tokens)
+
+ def token_ids_to_words(token_ids: List[int]) -> str:
+ text = ""
+ for i in token_ids:
+ text += symbol_table[i]
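+ # "▁" is the SentencePiece word-boundary marker; replacing it with a space
+ # turns the concatenated BPE pieces back into words.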
+ return text.replace("▁", " ").strip()
+
+ for filename, hyp in zip(args.sound_files, hyps):
+ words = token_ids_to_words(hyp)
+ s += f"{filename}:\n{words}\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py
new file mode 100755
index 000000000..a22d1b4ba
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py
@@ -0,0 +1,355 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+./pruned_transducer_stateless7/export.py \
+ --exp-dir ./pruned_transducer_stateless7/exp \
+ --bpe-model data/en/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 5
+
+Usage of this script:
+
+(1) greedy search
+./pruned_transducer_stateless7/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+ --bpe-model ./data/en/lang_bpe_500/bpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless7/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+ --bpe-model ./data/en/lang_bpe_500/bpe.model \
+ --method beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless7/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+ --bpe-model ./data/en/lang_bpe_500/bpe.model \
+ --method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless7/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
+ --bpe-model ./data/en/lang_bpe_500/bpe.model \
+ --method fast_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
+./pruned_transducer_stateless7/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ help="""Path to bpe.model.""",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame. Used only when
+ --method is greedy_search.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list of 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+
+ params.update(vars(args))
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info("Creating model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
+
+ num_waves = encoder_out.size(0)
+ hyps = []
+ msg = f"Using {params.method}"
+ if params.method == "beam_search":
+ msg += f" with beam size {params.beam_size}"
+ logging.info(msg)
+
+ if params.method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ else:
+ for i in range(num_waves):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(f"Unsupported method: {params.method}")
+
+ hyps.append(sp.decode(hyp).split())
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
new file mode 100755
index 000000000..73a29a90a
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
@@ -0,0 +1,1250 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless7/exp \
+ --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless7/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7/exp \
+ --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import CommonVoiceAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ filter_uneven_sized_batch,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,4,3,2,4",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="1024,1024,2048,2048,1024",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="384,384,384,384,384",
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+ not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="Unmasked dimensions in the encoders; this relates to augmentation during "
+ "training. Must be <= each of encoder_dims. Empirically, values less than 256 "
+ "seem to make performance worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
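+ # For orientation only: the joiner combines the projected encoder and decoder
+ # outputs roughly as follows (a sketch; joiner.py is the authoritative source):
+ #
+ #   logit = output_linear(torch.tanh(
+ #       encoder_proj(encoder_out) + decoder_proj(decoder_out)
+ #   ))
+ #
+ # where both projections map to --joiner-dim.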
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/en/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.05, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not changing this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
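+ # Both --lr-batches and --lr-epochs feed the Eden scheduler from optim.py.
+ # Roughly (see the Eden docstring there for the authoritative formula), the
+ # learning rate decays as
+ #
+ #   lr = base_lr
+ #        * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
+ #        * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25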
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss; it means how many symbols (context) "
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version of the "
+ "loss (where the joiner is just addition); this simple loss is also used "
+ "for training (as a regularization term). We will scale the simple loss "
+ "with this parameter before adding it to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=2000,
+ help="""Save a checkpoint after processing this number of batches
+ periodically. We save a checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'.
+ Note: It also saves a checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch, where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
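+ # The update above is an incremental mean: assuming model_avg is updated
+ # exactly every `average_period` batches, after the k-th update it equals the
+ # arithmetic mean of the k sampled models. A tiny numeric sketch for a single
+ # scalar parameter w:
+ #
+ #   w_avg = 0.0
+ #   for k, w in enumerate([1.0, 2.0, 3.0], start=1):
+ #       w_avg = w * (1 / k) + w_avg * ((k - 1) / k)
+ #   # w_avg == 2.0, the mean of the three sampled values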
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "frame_shift_ms": 10.0,
+ "allowed_excess_duration_ratio": 0.1,
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.exp_dir/epoch-{params.start_epoch - 1}.pt`.
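+
+ For example (values chosen only for illustration), with start_batch=0 and
+ start_epoch=4 it loads `params.exp_dir/epoch-3.pt`, while with
+ start_batch=1000 it loads `params.exp_dir/checkpoint-1000.pt` instead.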
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ # For an uneven-sized batch, the total duration after padding could cause
+ # OOM. Hence, for each batch, which is sorted in descending order of
+ # duration, we simply drop the last few shortest samples, so that the
+ # retained total frames (after padding) do not exceed `allowed_max_frames`:
+ # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+ # where `max_frames = max_duration * 1000 // frame_shift_ms`.
+ # We set allowed_excess_duration_ratio=0.1.
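+ # For example (assumed values): with max_duration=300 seconds and
+ # frame_shift_ms=10, max_frames = 300 * 1000 // 10 = 30000 and
+ # allowed_max_frames = int(30000 * 1.1) = 33000.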
+ max_frames = params.max_duration * 1000 // params.frame_shift_ms
+ allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
+ batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
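+ # For example (assumed values): with warm_step=2000 and s=0.5, the scale
+ # at batch_idx_train=1000 is 1.0 - (1000 / 2000) * (1.0 - 0.5) = 0.75.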
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler; we call its step_batch() method after every batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of processes (GPUs) used in DDP training. If it is 1, DDP is
+ disabled.
+ rank:
+ The rank of the process in DDP training. If DDP is not used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
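+ # For example, a grad scale of 0.5 is doubled to 1.0 on the next multiple
+ # of 100 batches, while a scale of 4.0 is only doubled on every 400th batch.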
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The process with rank 0 is responsible for saving checkpoints.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ train_cuts = commonvoice.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold.
+ if c.duration < 1.0 or c.duration > 20.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
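+ # For example, an assumed 10-second cut at a 10 ms frame shift has
+ # c.num_frames = 1000, so T = ((1000 - 7) // 2 + 1) // 2 = 248.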
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = commonvoice.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = commonvoice.dev_cuts()
+ valid_dl = commonvoice.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/shared b/egs/commonvoice/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/commonvoice/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/csj/ASR/.gitignore b/egs/csj/ASR/.gitignore
index 5d965832e..cd0e20c4c 100644
--- a/egs/csj/ASR/.gitignore
+++ b/egs/csj/ASR/.gitignore
@@ -5,4 +5,4 @@ notify_tg.py
finetune_*
misc.ini
.vscode/*
-offline/*
\ No newline at end of file
+offline/*
diff --git a/egs/csj/ASR/README.md b/egs/csj/ASR/README.md
new file mode 100644
index 000000000..95c2ec6ac
--- /dev/null
+++ b/egs/csj/ASR/README.md
@@ -0,0 +1,11 @@
+# Introduction
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# Transducers
+
+These are the types of architectures currently available.
+
+| | Encoder | Decoder | Comment |
+|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
+| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | Adapted from librispeech pruned_transducer_stateless7_streaming |
diff --git a/egs/csj/ASR/RESULTS.md b/egs/csj/ASR/RESULTS.md
new file mode 100644
index 000000000..56fdb899f
--- /dev/null
+++ b/egs/csj/ASR/RESULTS.md
@@ -0,0 +1,200 @@
+# Results
+
+## Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)
+
+### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
+
+See for more details.
+
+You can find a pretrained model, training logs, decoding logs, and decoding results at:
+
+
+Number of model parameters: 75688409, i.e. 75.7M.
+
+#### training on disfluent transcript
+
+The CERs are:
+
+| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
+| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
+| fast beam search | 320ms | 5.39 | 4.08 | 4.16 | 5.4 | 5.02 | --epoch 30 --avg 17 | simulated streaming |
+| fast beam search | 320ms | 5.34 | 4.1 | 4.26 | 5.61 | 4.91 | --epoch 30 --avg 17 | chunk-wise |
+| greedy search | 320ms | 5.43 | 4.14 | 4.31 | 5.48 | 4.88 | --epoch 30 --avg 17 | simulated streaming |
+| greedy search | 320ms | 5.44 | 4.14 | 4.39 | 5.7 | 4.98 | --epoch 30 --avg 17 | chunk-wise |
+| modified beam search | 320ms | 5.2 | 3.95 | 4.09 | 5.12 | 4.75 | --epoch 30 --avg 17 | simulated streaming |
+| modified beam search | 320ms | 5.18 | 4.07 | 4.12 | 5.36 | 4.77 | --epoch 30 --avg 17 | chunk-wise |
+| fast beam search | 640ms | 5.01 | 3.78 | 3.96 | 4.85 | 4.6 | --epoch 30 --avg 17 | simulated streaming |
+| fast beam search | 640ms | 4.97 | 3.88 | 3.96 | 4.91 | 4.61 | --epoch 30 --avg 17 | chunk-wise |
+| greedy search | 640ms | 5.02 | 3.84 | 4.14 | 5.02 | 4.59 | --epoch 30 --avg 17 | simulated streaming |
+| greedy search | 640ms | 5.32 | 4.22 | 4.33 | 5.39 | 4.99 | --epoch 30 --avg 17 | chunk-wise |
+| modified beam search | 640ms | 4.78 | 3.66 | 3.85 | 4.72 | 4.42 | --epoch 30 --avg 17 | simulated streaming |
+| modified beam search | 640ms | 5.77 | 4.72 | 4.73 | 5.85 | 5.36 | --epoch 30 --avg 17 | chunk-wise |
+
+Note: `simulated streaming` indicates feeding the full utterance during decoding using `decode.py`,
+while `chunk-wise` indicates feeding a certain number of frames at a time using `streaming_decode.py`.
+
+The training command was:
+```bash
+./pruned_transducer_stateless7_streaming/train.py \
+ --feedforward-dims "1024,1024,2048,2048,1024" \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
+ --max-duration 375 \
+ --transcript-mode disfluent \
+ --lang data/lang_char \
+ --manifest-dir /mnt/host/corpus/csj/fbank \
+ --pad-feature 30 \
+ --musan-dir /mnt/host/corpus/musan/musan/fbank
+```
+
+The simulated streaming decoding command was:
+```bash
+for chunk in 64 32; do
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ python pruned_transducer_stateless7_streaming/decode.py \
+ --feedforward-dims "1024,1024,2048,2048,1024" \
+ --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
+ --epoch 30 \
+ --avg 17 \
+ --max-duration 350 \
+ --decoding-method $m \
+ --manifest-dir /mnt/host/corpus/csj/fbank \
+ --lang data/lang_char \
+ --transcript-mode disfluent \
+ --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/sim_"$chunk"_"$m" \
+ --decode-chunk-len $chunk \
+ --pad-feature 30 \
+ --gpu 0
+ done
+done
+```
+
+The streaming chunk-wise decoding command was:
+```bash
+for chunk in 64 32; do
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ python pruned_transducer_stateless7_streaming/streaming_decode.py \
+ --feedforward-dims "1024,1024,2048,2048,1024" \
+ --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
+ --epoch 30 \
+ --avg 17 \
+ --max-duration 350 \
+ --decoding-method $m \
+ --manifest-dir /mnt/host/corpus/csj/fbank \
+ --lang data/lang_char \
+ --transcript-mode disfluent \
+ --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/stream_"$chunk"_"$m" \
+ --decode-chunk-len $chunk \
+ --gpu 2 \
+ --num-decode-streams 40
+ done
+done
+```
+
+#### training on fluent transcript
+
+The CERs are:
+
+| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
+| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
+| fast beam search | 320ms | 4.19 | 3.63 | 3.77 | 4.43 | 4.09 | --epoch 30 --avg 12 | simulated streaming |
+| fast beam search | 320ms | 4.06 | 3.55 | 3.66 | 4.70 | 4.04 | --epoch 30 --avg 12 | chunk-wise |
+| greedy search | 320ms | 4.22 | 3.62 | 3.82 | 4.45 | 3.98 | --epoch 30 --avg 12 | simulated streaming |
+| greedy search | 320ms | 4.13 | 3.61 | 3.85 | 4.67 | 4.05 | --epoch 30 --avg 12 | chunk-wise |
+| modified beam search | 320ms | 4.02 | 3.43 | 3.62 | 4.43 | 3.81 | --epoch 30 --avg 12 | simulated streaming |
+| modified beam search | 320ms | 3.97 | 3.43 | 3.59 | 4.99 | 3.88 | --epoch 30 --avg 12 | chunk-wise |
+| fast beam search | 640ms | 3.80 | 3.31 | 3.55 | 4.16 | 3.90 | --epoch 30 --avg 12 | simulated streaming |
+| fast beam search | 640ms | 3.81 | 3.34 | 3.46 | 4.58 | 3.85 | --epoch 30 --avg 12 | chunk-wise |
+| greedy search | 640ms | 3.92 | 3.38 | 3.65 | 4.31 | 3.88 | --epoch 30 --avg 12 | simulated streaming |
+| greedy search | 640ms | 3.98 | 3.38 | 3.64 | 4.54 | 4.01 | --epoch 30 --avg 12 | chunk-wise |
+| modified beam search | 640ms | 3.72 | 3.26 | 3.39 | 4.10 | 3.65 | --epoch 30 --avg 12 | simulated streaming |
+| modified beam search | 640ms | 3.78 | 3.32 | 3.45 | 4.81 | 3.81 | --epoch 30 --avg 12 | chunk-wise |
+
+Note: `simulated streaming` indicates feeding the full utterance during decoding using `decode.py`,
+while `chunk-wise` indicates feeding a certain number of frames at a time using `streaming_decode.py`.
+
+The training command was:
+```bash
+./pruned_transducer_stateless7_streaming/train.py \
+ --feedforward-dims "1024,1024,2048,2048,1024" \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
+ --max-duration 375 \
+ --transcript-mode fluent \
+ --lang data/lang_char \
+ --manifest-dir /mnt/host/corpus/csj/fbank \
+ --pad-feature 30 \
+ --musan-dir /mnt/host/corpus/musan/musan/fbank
+```
+
+The simulated streaming decoding command was:
+```bash
+for chunk in 64 32; do
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ python pruned_transducer_stateless7_streaming/decode.py \
+ --feedforward-dims "1024,1024,2048,2048,1024" \
+ --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
+ --epoch 30 \
+ --avg 12 \
+ --max-duration 350 \
+ --decoding-method $m \
+ --manifest-dir /mnt/host/corpus/csj/fbank \
+ --lang data/lang_char \
+ --transcript-mode fluent \
+ --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/sim_"$chunk"_"$m" \
+ --decode-chunk-len $chunk \
+ --pad-feature 30 \
+ --gpu 1
+ done
+done
+```
+
+The streaming chunk-wise decoding command was:
+```bash
+for chunk in 64 32; do
+ for m in greedy_search fast_beam_search modified_beam_search; do
+ python pruned_transducer_stateless7_streaming/streaming_decode.py \
+ --feedforward-dims "1024,1024,2048,2048,1024" \
+ --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
+ --epoch 30 \
+ --avg 12 \
+ --max-duration 350 \
+ --decoding-method $m \
+ --manifest-dir /mnt/host/corpus/csj/fbank \
+ --lang data/lang_char \
+ --transcript-mode fluent \
+ --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/stream_"$chunk"_"$m" \
+ --decode-chunk-len $chunk \
+ --gpu 3 \
+ --num-decode-streams 40
+ done
+done
+```
+
+#### Comparing disfluent to fluent
+
+$$ \texttt{CER}^{f}_d = \frac{\texttt{sub}_f + \texttt{ins} + \texttt{del}_f}{N_f} $$
+
+This comparison evaluates the disfluent model on the fluent transcript (calculated by `disfluent_recogs_to_fluent.py`), forgiving the disfluent model's mistakes on fillers and partial words. It is meant as an illustrative metric only, so that the disfluent and fluent models can be compared.
+
+| decoding method | chunk size | eval1 (d vs f) | eval2 (d vs f) | eval3 (d vs f) | excluded (d vs f) | valid (d vs f) | decoding mode |
+| --------------- | ---------- | -------------- | --------------- | -------------- | -------------------- | --------------- | ----------- |
+| fast beam search | 320ms | 4.54 vs 4.19 | 3.44 vs 3.63 | 3.56 vs 3.77 | 4.22 vs 4.43 | 4.22 vs 4.09 | simulated streaming |
+| fast beam search | 320ms | 4.48 vs 4.06 | 3.41 vs 3.55 | 3.65 vs 3.66 | 4.26 vs 4.7 | 4.08 vs 4.04 | chunk-wise |
+| greedy search | 320ms | 4.53 vs 4.22 | 3.48 vs 3.62 | 3.69 vs 3.82 | 4.38 vs 4.45 | 4.05 vs 3.98 | simulated streaming |
+| greedy search | 320ms | 4.53 vs 4.13 | 3.46 vs 3.61 | 3.71 vs 3.85 | 4.48 vs 4.67 | 4.12 vs 4.05 | chunk-wise |
+| modified beam search | 320ms | 4.45 vs 4.02 | 3.38 vs 3.43 | 3.57 vs 3.62 | 4.19 vs 4.43 | 4.04 vs 3.81 | simulated streaming |
+| modified beam search | 320ms | 4.44 vs 3.97 | 3.47 vs 3.43 | 3.56 vs 3.59 | 4.28 vs 4.99 | 4.04 vs 3.88 | chunk-wise |
+| fast beam search | 640ms | 4.14 vs 3.8 | 3.12 vs 3.31 | 3.38 vs 3.55 | 3.72 vs 4.16 | 3.81 vs 3.9 | simulated streaming |
+| fast beam search | 640ms | 4.05 vs 3.81 | 3.23 vs 3.34 | 3.36 vs 3.46 | 3.65 vs 4.58 | 3.78 vs 3.85 | chunk-wise |
+| greedy search | 640ms | 4.1 vs 3.92 | 3.17 vs 3.38 | 3.5 vs 3.65 | 3.87 vs 4.31 | 3.77 vs 3.88 | simulated streaming |
+| greedy search | 640ms | 4.41 vs 3.98 | 3.56 vs 3.38 | 3.69 vs 3.64 | 4.26 vs 4.54 | 4.16 vs 4.01 | chunk-wise |
+| modified beam search | 640ms | 4 vs 3.72 | 3.08 vs 3.26 | 3.33 vs 3.39 | 3.75 vs 4.1 | 3.71 vs 3.65 | simulated streaming |
+| modified beam search | 640ms | 5.05 vs 3.78 | 4.22 vs 3.32 | 4.26 vs 3.45 | 5.02 vs 4.81 | 4.73 vs 3.81 | chunk-wise |
+| average (d - f) | | 0.43 | -0.02 | -0.02 | -0.34 | 0.13 | |
diff --git a/egs/csj/ASR/local/add_transcript_mode.py b/egs/csj/ASR/local/add_transcript_mode.py
new file mode 100644
index 000000000..f6b4b2caf
--- /dev/null
+++ b/egs/csj/ASR/local/add_transcript_mode.py
@@ -0,0 +1,94 @@
+import argparse
+import logging
+from configparser import ConfigParser
+from pathlib import Path
+from typing import List, Tuple
+
+from lhotse import CutSet, SupervisionSet
+from lhotse.recipes.csj import CSJSDBParser
+
+ARGPARSE_DESCRIPTION = """
+This script adds transcript modes to an existing CutSet or SupervisionSet.
+"""
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ description=ARGPARSE_DESCRIPTION,
+ )
+ parser.add_argument(
+ "-f",
+ "--fbank-dir",
+ type=Path,
+ help="Path to directory where manifests are stored.",
+ )
+ parser.add_argument(
+ "-c",
+ "--config",
+ type=Path,
+ nargs="+",
+ help="Path to config file for transcript parsing.",
+ )
+ return parser.parse_args()
+
+
+def get_CSJParsers(config_files: List[Path]) -> List[Tuple[str, CSJSDBParser]]:
+ parsers = []
+ for config_file in config_files:
+ config = ConfigParser()
+ config.optionxform = str
+ assert config.read(config_file), f"{config_file} could not be found."
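+ # Build the DECISIONS mapping from the ini file, e.g. {"F": 0, "D": 0, "A": 1, ...};
+ # values that cannot be parsed as integers are kept as strings.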
+ decisions = {}
+ for k, v in config["DECISIONS"].items():
+ try:
+ decisions[k] = int(v)
+ except ValueError:
+ decisions[k] = v
+ parsers.append(
+ (config["CONSTANTS"].get("MODE"), CSJSDBParser(decisions=decisions))
+ )
+ return parsers
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(
+ format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
+ level=logging.INFO,
+ )
+ parsers = get_CSJParsers(args.config)
+
+ logging.info(f"Adding {', '.join(x[0] for x in parsers)} transcript mode.")
+
+ manifests = list(args.fbank_dir.glob("csj_cuts_*.jsonl.gz"))
+ assert manifests, f"No cuts to be found in {args.fbank_dir}"
+
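+ # For each cuts manifest, parse the raw transcript stored in
+ # cut.supervisions[0].custom["raw"] once per configured mode and store the
+ # result under cut.supervisions[0].custom[mode_name]; the original manifest
+ # is kept as a "bak."-prefixed backup before the new one is written.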
+ for manifest in manifests:
+ results = []
+ logging.info(f"Adding transcript modes to {manifest.name} now.")
+ cutset = CutSet.from_file(manifest)
+ for cut in cutset:
+ for name, parser in parsers:
+ cut.supervisions[0].custom[name] = parser.parse(
+ cut.supervisions[0].custom["raw"]
+ )
+ cut.supervisions[0].text = ""
+ results.append(cut)
+ results = CutSet.from_items(results)
+ res_file = manifest.as_posix()
+ manifest.replace(manifest.parent / ("bak." + manifest.name))
+ results.to_file(res_file)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py
index 994dedbdd..ce560025d 100644
--- a/egs/csj/ASR/local/compute_fbank_csj.py
+++ b/egs/csj/ASR/local/compute_fbank_csj.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
+# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -19,36 +19,23 @@
import argparse
import logging
import os
-from itertools import islice
from pathlib import Path
-from random import Random
from typing import List, Tuple
import torch
-from lhotse import (
+
+# fmt: off
+from lhotse import ( # See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527
CutSet,
Fbank,
FbankConfig,
- # fmt: off
- # See the following for why LilcomChunkyWriter is preferred
- # https://github.com/k2-fsa/icefall/pull/404
- # https://github.com/lhotse-speech/lhotse/pull/527
- # fmt: on
LilcomChunkyWriter,
RecordingSet,
SupervisionSet,
)
+from lhotse.recipes.csj import concat_csj_supervisions
-ARGPARSE_DESCRIPTION = """
-This script follows the espnet method of splitting the remaining core+noncore
-utterances into valid and train cutsets at an index which is by default 4000.
-
-In other words, the core+noncore utterances are shuffled, where 4000 utterances
-of the shuffled set go to the `valid` cutset and are not subject to speed
-perturbation. The remaining utterances become the `train` cutset and are speed-
-perturbed (0.9x, 1.0x, 1.1x).
-
-"""
+# fmt: on
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@@ -58,78 +45,100 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
RNG_SEED = 42
+# concat_params_train = [
+# {"gap": 1.0, "maxlen": 10.0},
+# {"gap": 1.5, "maxlen": 8.0},
+# {"gap": 1.0, "maxlen": 18.0},
+# ]
+
+concat_params = {"gap": 1.0, "maxlen": 10.0}
def make_cutset_blueprints(
manifest_dir: Path,
- split: int,
) -> List[Tuple[str, CutSet]]:
cut_sets = []
+ logging.info("Creating non-train cuts.")
+
# Create eval datasets
- logging.info("Creating eval cuts.")
for i in range(1, 4):
+ sps = sorted(
+ SupervisionSet.from_file(
+ manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz"
+ ),
+ key=lambda x: x.id,
+ )
+
cut_set = CutSet.from_manifests(
recordings=RecordingSet.from_file(
manifest_dir / f"csj_recordings_eval{i}.jsonl.gz"
),
- supervisions=SupervisionSet.from_file(
- manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz"
- ),
+ supervisions=concat_csj_supervisions(sps, **concat_params),
)
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
cut_sets.append((f"eval{i}", cut_set))
- # Create train and valid cuts
- logging.info(
- "Loading, trimming, and shuffling the remaining core+noncore cuts."
+ # Create excluded dataset
+ sps = sorted(
+ SupervisionSet.from_file(manifest_dir / "csj_supervisions_excluded.jsonl.gz"),
+ key=lambda x: x.id,
)
- recording_set = RecordingSet.from_file(
- manifest_dir / "csj_recordings_core.jsonl.gz"
- ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
- supervision_set = SupervisionSet.from_file(
- manifest_dir / "csj_supervisions_core.jsonl.gz"
- ) + SupervisionSet.from_file(
- manifest_dir / "csj_supervisions_noncore.jsonl.gz"
- )
-
cut_set = CutSet.from_manifests(
- recordings=recording_set,
- supervisions=supervision_set,
+ recordings=RecordingSet.from_file(
+ manifest_dir / "csj_recordings_excluded.jsonl.gz"
+ ),
+ supervisions=concat_csj_supervisions(sps, **concat_params),
)
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
- cut_set = cut_set.shuffle(Random(RNG_SEED))
+ cut_sets.append(("excluded", cut_set))
- logging.info(
- "Creating valid and train cuts from core and noncore,"
- f"split at {split}."
+ # Create valid dataset
+ sps = sorted(
+ SupervisionSet.from_file(manifest_dir / "csj_supervisions_valid.jsonl.gz"),
+ key=lambda x: x.id,
)
- valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
+ cut_set = CutSet.from_manifests(
+ recordings=RecordingSet.from_file(
+ manifest_dir / "csj_recordings_valid.jsonl.gz"
+ ),
+ supervisions=concat_csj_supervisions(sps, **concat_params),
+ )
+ cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
+ cut_sets.append(("valid", cut_set))
- train_set = CutSet.from_cuts(islice(cut_set, split, None))
- train_set = (
- train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
+ logging.info("Creating train cuts.")
+
+ # Create train dataset
+ sps = sorted(
+ SupervisionSet.from_file(manifest_dir / "csj_supervisions_core.jsonl.gz")
+ + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz"),
+ key=lambda x: x.id,
)
- cut_sets.extend([("valid", valid_set), ("train", train_set)])
+ recording = RecordingSet.from_file(
+ manifest_dir / "csj_recordings_core.jsonl.gz"
+ ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
+
+ train_set = CutSet.from_manifests(
+ recordings=recording, supervisions=concat_csj_supervisions(sps, **concat_params)
+ ).trim_to_supervisions(keep_overlapping=False)
+ train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
+
+ cut_sets.append(("train", train_set))
return cut_sets
def get_args():
parser = argparse.ArgumentParser(
- description=ARGPARSE_DESCRIPTION,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-
parser.add_argument(
- "--manifest-dir", type=Path, help="Path to save manifests"
+ "-m", "--manifest-dir", type=Path, help="Path to save manifests"
)
parser.add_argument(
- "--fbank-dir", type=Path, help="Path to save fbank features"
- )
- parser.add_argument(
- "--split", type=int, default=4000, help="Split at this index"
+ "-f", "--fbank-dir", type=Path, help="Path to save fbank features"
)
return parser.parse_args()
@@ -141,9 +150,7 @@ def main():
extractor = Fbank(FbankConfig(num_mel_bins=80))
num_jobs = min(16, os.cpu_count())
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
@@ -154,7 +161,7 @@ def main():
)
return
else:
- cut_sets = make_cutset_blueprints(args.manifest_dir, args.split)
+ cut_sets = make_cutset_blueprints(args.manifest_dir)
for part, cut_set in cut_sets:
logging.info(f"Processing {part}")
cut_set = cut_set.compute_and_store_features(
@@ -163,7 +170,7 @@ def main():
storage_path=(args.fbank_dir / f"feats_{part}").as_posix(),
storage_type=LilcomChunkyWriter,
)
- cut_set.to_file(args.manifest_dir / f"csj_cuts_{part}.jsonl.gz")
+ cut_set.to_file(args.fbank_dir / f"csj_cuts_{part}.jsonl.gz")
logging.info("All fbank computed for CSJ.")
(args.fbank_dir / ".done").touch()
diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py
index 44a33c4eb..c942df98e 100644
--- a/egs/csj/ASR/local/compute_fbank_musan.py
+++ b/egs/csj/ASR/local/compute_fbank_musan.py
@@ -26,12 +26,9 @@ from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
-
ARGPARSE_DESCRIPTION = """
This file computes fbank features of the musan dataset.
-It looks for manifests in the directory data/manifests.
-The generated fbank features are saved in data/fbank.
"""
# Torch's multithreaded behavior needs to be disabled or
@@ -43,8 +40,6 @@ torch.set_num_interop_threads(1)
def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
- # src_dir = Path("data/manifests")
- # output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
@@ -84,9 +79,7 @@ def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
- recordings=combine(
- part["recordings"] for part in manifests.values()
- )
+ recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
@@ -108,10 +101,10 @@ def get_args():
)
parser.add_argument(
- "--manifest-dir", type=Path, help="Path to save manifests"
+ "-m", "--manifest-dir", type=Path, help="Path to save manifests"
)
parser.add_argument(
- "--fbank-dir", type=Path, help="Path to save fbank features"
+ "-f", "--fbank-dir", type=Path, help="Path to save fbank features"
)
return parser.parse_args()
@@ -119,9 +112,7 @@ def get_args():
if __name__ == "__main__":
args = get_args()
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan(args.manifest_dir, args.fbank_dir)
diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini
index eb70673de..4f0a9ec0e 100644
--- a/egs/csj/ASR/local/conf/disfluent.ini
+++ b/egs/csj/ASR/local/conf/disfluent.ini
@@ -1,321 +1,79 @@
-; # This section is ignored if this file is not supplied as the first config file to
-; # lhotse prepare csj
-[SEGMENTS]
-; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
-gap = 0.5
-; # Maximum length of segment (s).
-maxlen = 10
-; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
-minlen = 0.02
-; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
-; # Pass an empty string to avoid adding any symbol. It was "" in kaldi.
-; # If you intend to use a multicharacter string for gap_sym, remember to register the
-; # multicharacter string as part of userdef-string in prepare_lang_char.py.
-gap_sym =
-
[CONSTANTS]
; # Name of this mode
MODE = disfluent
-; # Suffixes to use after the word surface (no longer used)
-MORPH = pos1 cForm cType2 pos2
-; # Used to differentiate between A tag and A_num tag
-JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 .
-; # Dummy character to delineate multiline words
-PLUS = +
[DECISIONS]
-; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
-; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries
-
; # フィラー、感情表出系感動詞
; # 0 to remain, 1 to delete
; # Example: '(F ぎょっ)'
F = 0
-; # Example: '(L (F ン))', '比べ(F えー)る'
-F^ = 0
; # 言い直し、いいよどみなどによる語断片
; # 0 to remain, 1 to delete
; # Example: '(D だ)(D だいが) 大学の学部の会議'
D = 0
-; # Example: '(L (D ドゥ)+(D ヒ))'
-D^ = 0
; # 助詞、助動詞、接辞の言い直し
; # 0 to remain, 1 to delete
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
D2 = 0
-; # Example: '(X (D2 ノ))'
-D2^ = 0
; # 聞き取りや語彙の判断に自信がない場合
; # 0 to remain, 1 to delete
; # Example: (? 字数) の
; # If no option: empty string is returned regardless of output
; # Example: '(?) で'
? = 0
-; # Example: '(D (? すー))+そう+です+よ+ね'
-?^ = 0
; # タグ?で、値は複数の候補が想定される場合
; # 0 for main guess with matching morph info, 1 for second guess
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
?, = 0
-; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
-?,^ = 0
; # 音や言葉に関するメタ的な引用
; # 0 to remain, 1 to delete
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
M = 0
-; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))'
-M^ = 0
; # 外国語や古語、方言など
; # 0 to remain, 1 to delete
; # Example: '(O ザッツファイン)'
O = 0
-; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
-O^ = 0
; # 講演者の名前、差別語、誹謗中傷など
; # 0 to remain, 1 to delete
; # Example: '国語研の (R ××) です'
R = 0
-R^ = 0
; # 非朗読対象発話(朗読における言い間違い等)
; # 0 to remain, 1 to delete
; # Example: '(X 実際は) 実際には'
X = 0
-; # Example: '(L (X (D2 ニ)))'
-X^ = 0
; # アルファベットや算用数字、記号の表記
; # 0 to use Japanese form, 1 to use alphabet form
; # Example: '(A シーディーアール;CD-R)'
A = 1
-; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
-A^ = 1
; # タグAで、単語は算用数字の場合
; # 0 to use Japanese form, 1 to use Arabic numerals
; # Example: (A 二千;2000)
-A_num = eval:self.notag
-A_num^ = eval:self.notag
+A_num = 0
; # 何らかの原因で漢字表記できなくなった場合
; # 0 to use broken form, 1 to use orthodox form
; # Example: '(K たち (F えー) ばな;橘)'
K = 1
-; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
-K^ = 1
; # 転訛、発音の怠けなど、一時的な発音エラー
; # 0 to use wrong form, 1 to use orthodox form
; # Example: '(W ギーツ;ギジュツ)'
W = 1
-; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
-W^ = 1
; # 語の読みに関する知識レベルのいい間違い
; # 0 to use wrong form, 1 to use orthodox form
; # Example: '(B シブタイ;ジュータイ)'
B = 0
-; # Example: 'データー(B カズ;スー)'
-B^ = 0
; # 笑いながら発話
; # 0 to remain, 1 to delete
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
笑 = 0
-; # Example: 'コク(笑 サイ+(D オン))',
-笑^ = 0
; # 泣きながら発話
; # 0 to remain, 1 to delete
-; # Example: '(泣 ドンナニ)'
+; # Example: '(泣 ドンナニ)'
泣 = 0
-泣^ = 0
; # 咳をしながら発話
; # 0 to remain, 1 to delete
-; # Example: 'シャ(咳 リン) ノ'
+; # Example: 'シャ(咳 リン) ノ'
咳 = 0
-; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
-咳^ = 0
; # ささやき声や独り言などの小さな声
; # 0 to remain, 1 to delete
-; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
+; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
L = 0
-; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
-L^ = 0
-
-[REPLACEMENTS]
-; # ボーカルフライなどで母音が同定できない場合
- =
-; # 「うん/うーん/ふーん」の音の特定が困難な場合
- =
-; # 非語彙的な母音の引き延ばし
-