diff --git a/.flake8 b/.flake8
index 19c3a9bd6..cbf0d8484 100644
--- a/.flake8
+++ b/.flake8
@@ -4,10 +4,20 @@ statistics=true
max-line-length = 80
per-file-ignores =
# line too long
- egs/librispeech/ASR/*/conformer.py: E501,
- egs/aishell/ASR/*/conformer.py: E501,
+ icefall/diagnostics.py: E501,
+ egs/*/ASR/*/conformer.py: E501,
+ egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
+ egs/*/ASR/*/optim.py: E501,
+ egs/*/ASR/*/scaling.py: E501,
+ egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203,
+ egs/librispeech/ASR/conformer_ctc2/*py: E501,
+ egs/librispeech/ASR/RESULTS.md: E999,
+
+# invalid escape sequence (caused by TeX formulas), W605
+ icefall/utils.py: E501, W605
exclude =
.git,
**/data/**,
- icefall/shared/make_kn_lm.py
+ icefall/shared/make_kn_lm.py,
+ icefall/__init__.py
diff --git a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
new file mode 100755
index 000000000..a4a6cd8d7
--- /dev/null
+++ b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
@@ -0,0 +1,17 @@
+#!/usr/bin/env bash
+
+# This script computes fbank features for the test-clean and test-other datasets.
+# The computed features are saved to ~/tmp/fbank-libri, which is
+# cached by GitHub Actions for later runs.
+
+export PYTHONPATH=$PWD:$PYTHONPATH
+echo $PYTHONPATH
+
+mkdir ~/tmp/fbank-libri
+cd egs/librispeech/ASR
+mkdir -p data
+cd data
+[ ! -e fbank ] && ln -s ~/tmp/fbank-libri fbank
+cd ..
+./local/compute_fbank_librispeech.py
+ls -lh data/fbank/
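
The script above uses the cache-and-symlink pattern shared by all of these CI helpers: the expensive artifact lives in a directory under ~/tmp that GitHub Actions caches, and the recipe sees it through a symlink. A minimal bash sketch of the pattern, reusing the same paths as the script (a sketch, not part of the patch):

#!/usr/bin/env bash
# Keep costly artifacts in a cached directory and expose them to the
# recipe via a symlink, as the script above does.
cache_dir=~/tmp/fbank-libri          # restored by actions/cache between runs
recipe_data=egs/librispeech/ASR/data

mkdir -p "$cache_dir" "$recipe_data"
# Create the symlink only if it does not exist yet.
[ ! -e "$recipe_data/fbank" ] && ln -s "$cache_dir" "$recipe_data/fbank"
ls -lh "$recipe_data/fbank"
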
diff --git a/.github/scripts/download-gigaspeech-dev-test-dataset.sh b/.github/scripts/download-gigaspeech-dev-test-dataset.sh
new file mode 100755
index 000000000..b9464de9f
--- /dev/null
+++ b/.github/scripts/download-gigaspeech-dev-test-dataset.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+
+# This script downloads the pre-computed fbank features for
+# dev and test datasets of GigaSpeech.
+#
+# You will find the directory `~/tmp/giga-dev-dataset-fbank` after running
+# this script.
+
+mkdir -p ~/tmp
+cd ~/tmp
+
+git lfs install
+git clone https://huggingface.co/csukuangfj/giga-dev-dataset-fbank
+
+ls -lh giga-dev-dataset-fbank/data/fbank
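
If git-lfs is not initialized, a clone like the one above silently leaves small text pointer stubs in place of the fbank archives. A hedged sanity check, assuming the standard git-lfs pointer header:

# Any text file that still begins with the LFS pointer header was not
# actually downloaded; `git lfs pull` would fetch the real content.
cd ~/tmp/giga-dev-dataset-fbank
if grep -rIl "^version https://git-lfs.github.com/spec/v1" data/fbank >/dev/null 2>&1; then
  echo "Some files are still LFS pointers; run: git lfs pull"
else
  echo "LFS objects appear to be fully downloaded."
fi
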
diff --git a/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh b/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
new file mode 100755
index 000000000..3efcc13e3
--- /dev/null
+++ b/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+
+# This script downloads the test-clean and test-other datasets
+# of LibriSpeech and extracts them to the folder ~/tmp/download,
+# which is cached by GitHub actions for later runs.
+#
+# You will find the directory ~/tmp/download/LibriSpeech after running
+# this script.
+
+mkdir ~/tmp/download
+cd egs/librispeech/ASR
+ln -s ~/tmp/download .
+cd download
+wget -q --no-check-certificate https://www.openslr.org/resources/12/test-clean.tar.gz
+tar xf test-clean.tar.gz
+rm test-clean.tar.gz
+
+wget -q --no-check-certificate https://www.openslr.org/resources/12/test-other.tar.gz
+tar xf test-other.tar.gz
+rm test-other.tar.gz
+pwd
+ls -lh
+ls -lh LibriSpeech
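
The same download-and-extract steps extend to the other LibriSpeech subsets hosted on OpenSLR; for example, the dev sets (not needed by this CI job) could be fetched the same way:

# Sketch: fetch additional subsets with the same pattern as above.
cd ~/tmp/download
for part in dev-clean dev-other; do
  wget -q --no-check-certificate https://www.openslr.org/resources/12/${part}.tar.gz
  tar xf ${part}.tar.gz
  rm ${part}.tar.gz
done
ls -lh LibriSpeech
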
diff --git a/.github/scripts/install-kaldifeat.sh b/.github/scripts/install-kaldifeat.sh
new file mode 100755
index 000000000..6666a5064
--- /dev/null
+++ b/.github/scripts/install-kaldifeat.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+
+# This script installs kaldifeat into the directory ~/tmp/kaldifeat
+# which is cached by GitHub actions for later runs.
+
+mkdir -p ~/tmp
+cd ~/tmp
+git clone https://github.com/csukuangfj/kaldifeat
+cd kaldifeat
+mkdir build
+cd build
+cmake -DCMAKE_BUILD_TYPE=Release ..
+make -j2 _kaldifeat
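
Note that the build is never pip-installed; the workflows below make the cached build tree importable by extending PYTHONPATH. A minimal sketch of how a later step consumes it, with the paths taken from the workflow files in this patch:

# Point Python at the in-tree kaldifeat build produced by this script.
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
python3 -c "import kaldifeat; print(kaldifeat.__file__)"
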
diff --git a/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh b/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
new file mode 100755
index 000000000..e0b87e0fc
--- /dev/null
+++ b/.github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+# This script assumes that test-clean and test-other are downloaded
+# to egs/librispeech/ASR/download/LibriSpeech and generates manifest
+# files in egs/librispeech/ASR/data/manifests
+
+cd egs/librispeech/ASR
+[ ! -e download ] && ln -s ~/tmp/download .
+mkdir -p data/manifests
+lhotse prepare librispeech -j 2 -p test-clean -p test-other ./download/LibriSpeech data/manifests
+ls -lh data/manifests
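
lhotse prepare librispeech takes one -p flag per subset, so the command above can be narrowed or extended in the same way; a sketch for the dev sets (not downloaded by this CI job, so only illustrative):

# Prepare manifests for additional subsets into the same directory,
# assuming those subsets exist under ./download/LibriSpeech.
lhotse prepare librispeech -j 2 \
  -p dev-clean -p dev-other \
  ./download/LibriSpeech data/manifests
ls -lh data/manifests
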
diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
new file mode 100755
index 000000000..631707ad9
--- /dev/null
+++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
@@ -0,0 +1,86 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/aishell/ASR
+
+git lfs install
+
+fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests
+log "Downloading pre-computed fbank from $fbank_url"
+
+git clone $fbank_url
+ln -s $PWD/aishell-test-dev-manifests/data .
+
+repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
+log "Downloading pre-trained model from $repo_url"
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt pretrained.pt
+popd
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+    $repo/test_wavs/BAC009S0764W0123.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+    $repo/test_wavs/BAC009S0764W0123.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless3/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_char data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless3/exp
+
+ log "Decoding test and dev"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless3/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless3/exp
+ done
+
+ rm pruned_transducer_stateless3/exp/*.pt
+fi
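
The decoding branch above relies on a small trick shared by all of these scripts: decode.py only loads checkpoints named exp/epoch-<N>.pt, so the pretrained checkpoint is symlinked in as a fake epoch 999 and decoded with --epoch 999 --avg 1 (i.e. no checkpoint averaging). The pattern in isolation, reusing the names from the script:

# Expose an arbitrary checkpoint to decode.py as "epoch 999".
mkdir -p pruned_transducer_stateless3/exp
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt

# --avg 1 means exactly that one checkpoint is used, with no averaging.
./pruned_transducer_stateless3/decode.py \
  --decoding-method greedy_search \
  --epoch 999 \
  --avg 1 \
  --max-duration 100 \
  --exp-dir pruned_transducer_stateless3/exp
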
diff --git a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh
new file mode 100755
index 000000000..528d04cd1
--- /dev/null
+++ b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/gigaspeech/ASR
+
+repo_url=https://huggingface.co/wgb14/icefall-asr-gigaspeech-pruned-transducer-stateless2
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless2/exp
+ ln -s $PWD/$repo/exp/pretrained-iter-3488000-avg-20.pt pruned_transducer_stateless2/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh data/lang_bpe_500
+ ls -lh data/fbank
+ ls -lh pruned_transducer_stateless2/exp
+
+ log "Decoding dev and test"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ # Test only greedy_search to reduce CI running time
+ # for method in greedy_search fast_beam_search modified_beam_search; do
+ for method in greedy_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless2/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless2/exp
+ done
+
+ rm pruned_transducer_stateless2/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
new file mode 100755
index 000000000..bd816c2d6
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
@@ -0,0 +1,76 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in fast_beam_search modified_beam_search beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless/exp
+ done
+
+ rm pruned_transducer_stateless/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
new file mode 100755
index 000000000..6b5b51bd7
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
@@ -0,0 +1,80 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+ln -s pretrained-epoch-38-avg-10.pt pretrained.pt
+popd
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless2/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless2/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless2/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless2/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless2/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless2/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless2/exp
+ done
+
+ rm pruned_transducer_stateless2/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
new file mode 100755
index 000000000..62ea02c47
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
@@ -0,0 +1,80 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+ln -s pretrained-epoch-25-avg-6.pt pretrained.pt
+popd
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless3/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless3/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless3/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless3/exp
+ done
+
+ rm pruned_transducer_stateless3/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
new file mode 100755
index 000000000..3617bc369
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
@@ -0,0 +1,80 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
+popd
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless3/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless3/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless3/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless3/exp
+ done
+
+ rm pruned_transducer_stateless3/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
new file mode 100755
index 000000000..c893bc45a
--- /dev/null
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
@@ -0,0 +1,99 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+ln -s pretrained-epoch-39-avg-7.pt pretrained.pt
+popd
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless5/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --num-encoder-layers 18 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --encoder-dim 512 \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless5/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav \
+ --num-encoder-layers 18 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --encoder-dim 512 \
+ --decoder-dim 512 \
+ --joiner-dim 512
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless5/exp
+ ln -s $PWD/$repo/exp/pretrained-epoch-39-avg-7.pt pruned_transducer_stateless5/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless5/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless5/decode.py \
+ --decoding-method $method \
+ --use-averaged-model 0 \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --num-encoder-layers 18 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --encoder-dim 512 \
+ --decoder-dim 512 \
+ --joiner-dim 512
+ done
+
+ rm pruned_transducer_stateless5/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
new file mode 100755
index 000000000..d9dc34e48
--- /dev/null
+++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
@@ -0,0 +1,100 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless2_20220625
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+pushd $repo/exp
+ln -s pretrained-epoch-24-avg-10.pt pretrained.pt
+popd
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless2/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --simulate-streaming 1 \
+ --causal-convolution 1 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless2/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --simulate-streaming 1 \
+ --causal-convolution 1 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless2/exp
+ ln -s $PWD/$repo/exp/pretrained-epoch-24-avg-10.pt pruned_transducer_stateless2/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless2/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Simulate streaming decoding with $method"
+
+ ./pruned_transducer_stateless2/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --simulate-streaming 1 \
+ --causal-convolution 1
+ done
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Real streaming decoding with $method"
+
+ ./pruned_transducer_stateless2/streaming_decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --num-decode-streams 100 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --left-context 32 \
+ --decode-chunk-size 8 \
+ --right-context 0
+ done
+
+ rm pruned_transducer_stateless2/exp/*.pt
+fi
diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
new file mode 100755
index 000000000..c22660d0a
--- /dev/null
+++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
@@ -0,0 +1,76 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless2/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in fast_beam_search modified_beam_search beam_search; do
+ log "$method"
+
+ ./transducer_stateless2/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p transducer_stateless2/exp
+ ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless2/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh transducer_stateless2/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./transducer_stateless2/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir transducer_stateless2/exp
+ done
+
+ rm transducer_stateless2/exp/*.pt
+fi
diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh
new file mode 100755
index 000000000..96a072c46
--- /dev/null
+++ b/.github/scripts/run-pre-trained-conformer-ctc.sh
@@ -0,0 +1,46 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
+
+git lfs install
+
+log "Downloading pre-trained model from $repo_url"
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.flac
+ls -lh $repo/test_wavs/*.flac
+
+log "CTC decoding"
+
+./conformer_ctc/pretrained.py \
+ --method ctc-decoding \
+ --num-classes 500 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.flac \
+ $repo/test_wavs/1221-135766-0001.flac \
+ $repo/test_wavs/1221-135766-0002.flac
+
+log "HLG decoding"
+
+./conformer_ctc/pretrained.py \
+ --method 1best \
+ --num-classes 500 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ $repo/test_wavs/1089-134686-0001.flac \
+ $repo/test_wavs/1221-135766-0001.flac \
+ $repo/test_wavs/1221-135766-0002.flac
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
new file mode 100755
index 000000000..dcc99d62e
--- /dev/null
+++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
@@ -0,0 +1,76 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p transducer_stateless_multi_datasets/exp
+ ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh transducer_stateless_multi_datasets/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./transducer_stateless_multi_datasets/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir transducer_stateless_multi_datasets/exp
+ done
+
+ rm transducer_stateless_multi_datasets/exp/*.pt
+fi
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
new file mode 100755
index 000000000..9622224c9
--- /dev/null
+++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
@@ -0,0 +1,76 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p transducer_stateless_multi_datasets/exp
+ ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh transducer_stateless_multi_datasets/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./transducer_stateless_multi_datasets/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir transducer_stateless_multi_datasets/exp
+ done
+
+ rm transducer_stateless_multi_datasets/exp/*.pt
+fi
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
new file mode 100755
index 000000000..168aee766
--- /dev/null
+++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
@@ -0,0 +1,47 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/aishell/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless_modified-2/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+done
+
+for method in modified_beam_search beam_search; do
+ log "$method"
+
+ ./transducer_stateless_modified-2/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+done
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
new file mode 100755
index 000000000..9211b22eb
--- /dev/null
+++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
@@ -0,0 +1,47 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/aishell/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless_modified/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+done
+
+for method in modified_beam_search beam_search; do
+ log "$method"
+
+ ./transducer_stateless_modified/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+done
diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh
new file mode 100755
index 000000000..4a1dc1a7e
--- /dev/null
+++ b/.github/scripts/run-pre-trained-transducer-stateless.sh
@@ -0,0 +1,76 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+for method in fast_beam_search modified_beam_search beam_search; do
+ log "$method"
+
+ ./transducer_stateless/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+done
+
+echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p transducer_stateless/exp
+ ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_bpe_500 data/
+
+ ls -lh data
+ ls -lh transducer_stateless/exp
+
+ log "Decoding test-clean and test-other"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./transducer_stateless/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir transducer_stateless/exp
+ done
+
+ rm transducer_stateless/exp/*.pt
+fi
diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh
new file mode 100755
index 000000000..5f8a5b3a5
--- /dev/null
+++ b/.github/scripts/run-pre-trained-transducer.sh
@@ -0,0 +1,32 @@
+#!/usr/bin/env bash
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+soxi $repo/test_wavs/*.wav
+ls -lh $repo/test_wavs/*.wav
+
+log "Beam search decoding"
+
+./transducer/pretrained.py \
+ --method beam_search \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
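
All of these helper scripts define the same log function; it prefixes each message with a timestamp plus the calling file, line, and function by looking one frame up the bash call stack. For reference, with an illustrative output line:

# The shared helper, reproduced outside a patch hunk for reference.
log() {
  # BASH_SOURCE[1], BASH_LINENO[0] and FUNCNAME[1] describe the caller,
  # not the log function itself.
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# Called from top-level script code, this prints something like
#   2021-12-23 10:00:00 (run-pre-trained-transducer.sh:21:main) Beam search decoding
log "Beam search decoding"
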
diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml
new file mode 100644
index 000000000..dd0969f51
--- /dev/null
+++ b/.github/workflows/build-doc.yml
@@ -0,0 +1,65 @@
+# Copyright 2022 Xiaomi Corp. (author: Fangjun Kuang)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# refer to https://github.com/actions/starter-workflows/pull/47/files
+
+# You can access it at https://k2-fsa.github.io/icefall/
+name: Generate doc
+on:
+ push:
+ branches:
+ - master
+ - doc
+ pull_request:
+ types: [labeled]
+
+jobs:
+ build-doc:
+ if: github.event.label.name == 'doc' || github.event_name == 'push'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.8"]
+ steps:
+ # refer to https://github.com/actions/checkout
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Display Python version
+ run: python -c "import sys; print(sys.version)"
+
+ - name: Build doc
+ shell: bash
+ run: |
+ cd docs
+ python3 -m pip install -r ./requirements.txt
+ make html
+ touch build/html/.nojekyll
+
+ - name: Deploy
+ uses: peaceiris/actions-gh-pages@v3
+ with:
+ github_token: ${{ secrets.GITHUB_TOKEN }}
+ publish_dir: ./docs/build/html
+ publish_branch: gh-pages
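
The same commands the workflow runs can be used to preview the documentation locally before pushing; a sketch assuming a working Python 3 environment:

# Build the docs exactly as the CI job does.
cd docs
python3 -m pip install -r ./requirements.txt
make html
# The generated pages land in docs/build/html; open index.html in a browser.
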
diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml
new file mode 100644
index 000000000..e684e598e
--- /dev/null
+++ b/.github/workflows/run-aishell-2022-06-20.yml
@@ -0,0 +1,119 @@
+# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-aishell-2022-06-20
+# pruned RNN-T + reworked model with random combiner
+# https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_aishell_2022_06_20:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
+
+ - name: Display decoding results for aishell pruned_transducer_stateless3
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/aishell/ASR/
+ tree ./pruned_transducer_stateless3/exp
+
+ cd pruned_transducer_stateless3
+ echo "results for pruned_transducer_stateless3"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
+
+ - name: Upload decoding results for aishell pruned_transducer_stateless3
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-06-20
+ path: egs/aishell/ASR/pruned_transducer_stateless3/exp/
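
The workflow passes github.event_name and github.event.label.name into the scripts as plain environment variables, and the scripts gate the expensive decoding runs on them. The handshake on the shell side, reduced to its core:

# Values provided by the workflow step's env: block.
#   GITHUB_EVENT_NAME       e.g. "push", "pull_request", "schedule"
#   GITHUB_EVENT_LABEL_NAME e.g. "run-decode" when that PR label is applied
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
  echo "nightly run or run-decode label: run full decoding on the test sets"
else
  echo "ordinary push/PR: only the quick pretrained.py checks are run"
fi
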
diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml
new file mode 100644
index 000000000..dc33751d3
--- /dev/null
+++ b/.github/workflows/run-gigaspeech-2022-05-13.yml
@@ -0,0 +1,122 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-gigaspeech-2022-05-13
+# stateless transducer + k2 pruned rnnt-loss + reworked conformer
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_gigaspeech_2022_05_13:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Download GigaSpeech dev/test dataset
+ shell: bash
+ run: |
+ sudo apt-get install -y -q git-lfs
+
+ .github/scripts/download-gigaspeech-dev-test-dataset.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ ln -s ~/tmp/giga-dev-dataset-fbank/data egs/gigaspeech/ASR/
+
+ ls -lh egs/gigaspeech/ASR/data/fbank
+
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh
+
+ - name: Display decoding results for gigaspeech pruned_transducer_stateless2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+        sudo apt-get -qq install tree
+
+        cd egs/gigaspeech/ASR/
+        tree ./pruned_transducer_stateless2/exp
+
+ cd pruned_transducer_stateless2
+ echo "results for pruned_transducer_stateless2"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
+
+ - name: Upload decoding results for gigaspeech pruned_transducer_stateless2
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
+ path: egs/gigaspeech/ASR/pruned_transducer_stateless2/exp/
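
The result-summary steps in these workflows are all built from the same find/grep/sort pipeline; spelled out once as a sketch over the log layout that decode.py produces:

# Pull the "best for <subset>" summary lines out of the decoding logs;
# grep -n prefixes each match with its file name and line number.
cd egs/gigaspeech/ASR/pruned_transducer_stateless2
find exp/greedy_search -name "log-*" \
  -exec grep -n --color "best for dev" {} + | sort -n -k2
# The same pipeline is repeated per test set and per decoding method.
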
diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml
new file mode 100644
index 000000000..291f2bc71
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-03-12.yml
@@ -0,0 +1,155 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-03-12
+# stateless transducer + k2 pruned rnnt-loss
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_librispeech_2022_03_12:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
+
+ - name: Display decoding results for pruned_transducer_stateless
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./pruned_transducer_stateless/exp
+
+ cd pruned_transducer_stateless
+ echo "results for pruned_transducer_stateless"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for pruned_transducer_stateless
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless-2022-03-12
+ path: egs/librispeech/ASR/pruned_transducer_stateless/exp/
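Every workflow added in this patch uses the same cache-then-install pattern for kaldifeat. The condensed sketch below only annotates the steps shown above; all names (`my-cache`, `~/tmp/kaldifeat`, `install-kaldifeat.sh`, the cache key) come from the workflows themselves, and only the comments are new.

```yaml
# Condensed sketch of the kaldifeat caching pattern used by these workflows.
- name: Cache kaldifeat
  id: my-cache                     # the id exposes steps.my-cache.outputs.cache-hit
  uses: actions/cache@v2
  with:
    path: |
      ~/tmp/kaldifeat
    # keyed on the Python version so each matrix entry gets its own build
    key: cache-tmp-${{ matrix.python-version }}

- name: Install kaldifeat
  # cache-hit is 'true' only on an exact key match; build only on a miss
  if: steps.my-cache.outputs.cache-hit != 'true'
  shell: bash
  run: |
    # builds into ~/tmp/kaldifeat, which is saved back to the cache at job end
    .github/scripts/install-kaldifeat.sh
```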
diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml
new file mode 100644
index 000000000..b04718f86
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-04-29.yml
@@ -0,0 +1,181 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-04-29
+# stateless pruned transducer (reworked model) + giga speech
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_librispeech_2022_04_29:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
+
+ .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
+
+ - name: Display decoding results for pruned_transducer_stateless2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR
+ tree pruned_transducer_stateless2/exp
+ cd pruned_transducer_stateless2/exp
+ echo "===greedy search==="
+ find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Display decoding results for pruned_transducer_stateless3
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR
+ tree pruned_transducer_stateless3/exp
+ cd pruned_transducer_stateless3/exp
+ echo "===greedy search==="
+ find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for pruned_transducer_stateless2
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-04-29
+ path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
+
+ - name: Upload decoding results for pruned_transducer_stateless3
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
+ path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml
new file mode 100644
index 000000000..bb3d74e55
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-05-13.yml
@@ -0,0 +1,155 @@
+# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-05-13
+# stateless transducer + k2 pruned rnnt-loss + deeper model
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_librispeech_2022_05_13:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
+
+ - name: Display decoding results for librispeech pruned_transducer_stateless5
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./pruned_transducer_stateless5/exp
+
+ cd pruned_transducer_stateless5
+ echo "results for pruned_transducer_stateless5"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for librispeech pruned_transducer_stateless5
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless5-2022-05-13
+ path: egs/librispeech/ASR/pruned_transducer_stateless5/exp/
diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
new file mode 100644
index 000000000..47976fc2c
--- /dev/null
+++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
@@ -0,0 +1,153 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-pruned-transducer-stateless3-2022-05-13
+# stateless pruned transducer (reworked model) + giga speech
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_librispeech_pruned_transducer_stateless3_2022_05_13:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
+
+ - name: Display decoding results for pruned_transducer_stateless3
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR
+ tree pruned_transducer_stateless3/exp
+ cd pruned_transducer_stateless3/exp
+ echo "===greedy search==="
+ find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for pruned_transducer_stateless3
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-05-13
+ path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
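The "Display decoding results" steps all rely on the decode scripts having written per-method logs under `exp/<method>/log-*`. Below is a minimal annotated sketch of that step; the paths and grep patterns are taken from the steps above, and only the comments are new.

```yaml
- name: Display decoding results for pruned_transducer_stateless3
  if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
  shell: bash
  run: |
    cd egs/librispeech/ASR/pruned_transducer_stateless3/exp
    # -exec grep -n prints "file:line: match" for every summary line;
    # sort -n -k2 then orders the matches numerically on the second field.
    find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
    find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
```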
diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
new file mode 100644
index 000000000..9ce8244da
--- /dev/null
+++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
@@ -0,0 +1,155 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-streaming-2022-06-26
+# streaming conformer stateless transducer2
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_librispeech_streaming_2022_06_26:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
+
+ - name: Display decoding results
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./pruned_transducer_stateless2/exp
+
+ cd pruned_transducer_stateless2
+ echo "results for pruned_transducer_stateless2"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified_beam_search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for pruned_transducer_stateless2
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-06-26
+ path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
new file mode 100644
index 000000000..e05b04bee
--- /dev/null
+++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
@@ -0,0 +1,155 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-04-19
+# stateless transducer + torchaudio rnn-t loss
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_librispeech_2022_04_19:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
+
+ - name: Display decoding results
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./transducer_stateless2/exp
+
+ cd transducer_stateless2
+ echo "results for transducer_stateless2"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified_beam_search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for transducer_stateless2
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless2-2022-04-19
+ path: egs/librispeech/ASR/transducer_stateless2/exp/
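All of the nightly workflows in this patch share the trigger and gating logic sketched below. The job id `some_recipe_job` is a placeholder (the real ids are dated per recipe); everything else mirrors the workflows above. Pushes to master and `ready`-labeled PRs run the inference steps, while the decode-result display and artifact upload additionally require the nightly schedule or the `run-decode` label.

```yaml
on:
  push:
    branches:
      - master
  pull_request:
    types: [labeled]        # fires each time a label is added to a PR

  schedule:
    - cron: "50 15 * * *"   # nightly at 15:50 UTC

jobs:
  some_recipe_job:          # placeholder id; the real ids are dated per recipe
    # run for pushes, nightly builds, or PRs labeled 'ready' / 'run-decode'
    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
    runs-on: ${{ matrix.os }}
```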
diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml
index 1758a3521..f4c6bf507 100644
--- a/.github/workflows/run-pretrained-conformer-ctc.yml
+++ b/.github/workflows/run-pretrained-conformer-ctc.yml
@@ -31,9 +31,6 @@ jobs:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
- torch: ["1.10.0"]
- torchaudio: ["0.10.0"]
- k2-version: ["1.9.dev20211101"]
fail-fast: false
@@ -43,67 +40,37 @@ jobs:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
- python3 -m pip install --upgrade pip pytest
- # numpy 1.20.x does not support python 3.6
- pip install numpy==1.19
- pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
- python3 -m pip install git+https://github.com/lhotse-speech/lhotse
- python3 -m pip install kaldifeat
- # We are in ./icefall and there is a file: requirements.txt in it
- pip install -r requirements.txt
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
- - name: Install graphviz
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
- python3 -m pip install -qq graphviz
- sudo apt-get -qq install graphviz
+ .github/scripts/install-kaldifeat.sh
- - name: Download pre-trained model
+ - name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
- cd egs/librispeech/ASR
- mkdir tmp
- cd tmp
- git lfs install
- git clone https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500
- cd ..
- tree tmp
- soxi tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
- ls -lh tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac
-
- - name: Run CTC decoding
- shell: bash
- run: |
- export PYTHONPATH=$PWD:PYTHONPATH
- cd egs/librispeech/ASR
- ./conformer_ctc/pretrained.py \
- --num-classes 500 \
- --checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
- --bpe-model ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/bpe.model \
- --method ctc-decoding \
- ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
- ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
- ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac
-
- - name: Run HLG decoding
- shell: bash
- run: |
export PYTHONPATH=$PWD:$PYTHONPATH
- cd egs/librispeech/ASR
- ./conformer_ctc/pretrained.py \
- --num-classes 500 \
- --checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \
- --words-file ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/words.txt \
- --HLG ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/HLG.pt \
- ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \
- ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \
- ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+ .github/scripts/run-pre-trained-conformer-ctc.sh
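The dependency steps replace the old per-package pins with a single `requirements-ci.txt`-driven install plus pip caching via `actions/setup-python@v2`. A condensed sketch is below; the commands are copied from the steps above, while the rationale in the comments is an interpretation rather than something stated in the patch.

```yaml
- name: Setup Python ${{ matrix.python-version }}
  uses: actions/setup-python@v2
  with:
    python-version: ${{ matrix.python-version }}
    cache: 'pip'                                    # built-in pip download cache
    cache-dependency-path: '**/requirements-ci.txt' # cache key follows this file

- name: Install Python dependencies
  run: |
    # Drop comment lines, then install the requirements one per pip invocation,
    # in file order (presumably so per-line pins and options apply cleanly).
    grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
    # Swap the binary protobuf wheel for a source build.
    pip uninstall -y protobuf
    pip install --no-binary protobuf protobuf
```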
diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
new file mode 100644
index 000000000..348a68095
--- /dev/null
+++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
@@ -0,0 +1,154 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-pre-trained-transducer-stateless-multi-datasets-librispeech-100h
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
+
+ - name: Display decoding results for transducer_stateless_multi_datasets
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./transducer_stateless_multi_datasets/exp
+
+ cd transducer_stateless_multi_datasets
+ echo "results for transducer_stateless_multi_datasets"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for transducer_stateless_multi_datasets
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
+ path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
new file mode 100644
index 000000000..d1369c2b1
--- /dev/null
+++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
@@ -0,0 +1,154 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-pre-trained-transducer-stateless-multi-datasets-librispeech-960h
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+jobs:
+ run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h:
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
+
+ - name: Display decoding results for transducer_stateless_multi_datasets
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./transducer_stateless_multi_datasets/exp
+
+ cd transducer_stateless_multi_datasets
+ echo "results for transducer_stateless_multi_datasets"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for transducer_stateless_multi_datasets
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+          name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-960h-2022-03-01
+ path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
new file mode 100644
index 000000000..9d095a0aa
--- /dev/null
+++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
@@ -0,0 +1,76 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-pre-trained-transducer-stateless-modified-2-aishell
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+jobs:
+ run_pre_trained_transducer_stateless_modified_2_aishell:
+ if: github.event.label.name == 'ready' || github.event_name == 'push'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ run: |
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+ .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
new file mode 100644
index 000000000..868fe6fbe
--- /dev/null
+++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
@@ -0,0 +1,76 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-pre-trained-transducer-stateless-modified-aishell
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ types: [labeled]
+
+jobs:
+ run_pre_trained_transducer_stateless_modified_aishell:
+ if: github.event.label.name == 'ready' || github.event_name == 'push'
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-18.04]
+ python-version: [3.7, 3.8, 3.9]
+
+ fail-fast: false
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ run: |
+ sudo apt-get -qq install git-lfs tree sox
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+ .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
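Every inference step sets up `PYTHONPATH` the same way before calling its recipe script: the repository root (so the in-repo `icefall` package is importable) plus the kaldifeat checkout that `install-kaldifeat.sh` left under `~/tmp/kaldifeat`. A sketch of that preamble follows, with comments added; the two kaldifeat entries are presumably the Python wrapper and the compiled extension from the cached build.

```yaml
- name: Inference with pre-trained model
  shell: bash
  run: |
    sudo apt-get -qq install git-lfs tree sox
    # repo root first, so `import icefall` resolves to this checkout
    export PYTHONPATH=$PWD:$PYTHONPATH
    # kaldifeat from the cached source build under ~/tmp/kaldifeat
    export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
    export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
    .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
```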
diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml
index 3bbd4c49b..78c1ca059 100644
--- a/.github/workflows/run-pretrained-transducer-stateless.yml
+++ b/.github/workflows/run-pretrained-transducer-stateless.yml
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-name: run-pre-trained-trandsucer-stateless
+name: run-pre-trained-transducer-stateless
on:
push:
@@ -23,17 +23,23 @@ on:
pull_request:
types: [labeled]
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
jobs:
run_pre_trained_transducer_stateless:
- if: github.event.label.name == 'ready' || github.event_name == 'push'
+ if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
- torch: ["1.10.0"]
- torchaudio: ["0.10.0"]
- k2-version: ["1.9.dev20211101"]
fail-fast: false
@@ -43,66 +49,106 @@ jobs:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
- python3 -m pip install --upgrade pip pytest
- # numpy 1.20.x does not support python 3.6
- pip install numpy==1.19
- pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
- python3 -m pip install git+https://github.com/lhotse-speech/lhotse
- python3 -m pip install kaldifeat
- # We are in ./icefall and there is a file: requirements.txt in it
- pip install -r requirements.txt
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
- - name: Install graphviz
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
- python3 -m pip install -qq graphviz
- sudo apt-get -qq install graphviz
+ .github/scripts/install-kaldifeat.sh
- - name: Download pre-trained model
+ - name: Cache LibriSpeech test-clean and test-other datasets
+ id: libri-test-clean-and-test-other-data
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/download
+ key: cache-libri-test-clean-and-test-other
+
+ - name: Download LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
shell: bash
run: |
+ .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+ - name: Prepare manifests for LibriSpeech test-clean and test-other
+ shell: bash
+ run: |
+ .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+ - name: Cache LibriSpeech test-clean and test-other fbank features
+ id: libri-test-clean-and-test-other-fbank
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/fbank-libri
+ key: cache-libri-fbank-test-clean-and-test-other-v2
+
+ - name: Compute fbank for LibriSpeech test-clean and test-other
+ if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+ - name: Inference with pre-trained model
+ shell: bash
+ env:
+ GITHUB_EVENT_NAME: ${{ github.event_name }}
+ GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ run: |
+ mkdir -p egs/librispeech/ASR/data
+ ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+ ls -lh egs/librispeech/ASR/data/*
+
sudo apt-get -qq install git-lfs tree sox
- cd egs/librispeech/ASR
- mkdir tmp
- cd tmp
- git lfs install
- git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27
- cd ..
- tree tmp
- soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
- ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
-
- - name: Run greedy search decoding
- shell: bash
- run: |
- export PYTHONPATH=$PWD:PYTHONPATH
- cd egs/librispeech/ASR
- ./transducer_stateless/pretrained.py \
- --method greedy_search \
- --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
- --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
-
- - name: Run beam search decoding
- shell: bash
- run: |
export PYTHONPATH=$PWD:$PYTHONPATH
- cd egs/librispeech/ASR
- ./transducer_stateless/pretrained.py \
- --method beam_search \
- --beam-size 4 \
- --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
- --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/run-pre-trained-transducer-stateless.sh
+
+ - name: Display decoding results for transducer_stateless
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ shell: bash
+ run: |
+ cd egs/librispeech/ASR/
+ tree ./transducer_stateless/exp
+
+ cd transducer_stateless
+ echo "results for transducer_stateless"
+ echo "===greedy search==="
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===fast_beam_search==="
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ echo "===modified beam search==="
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
+ find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
+
+ - name: Upload decoding results for transducer_stateless
+ uses: actions/upload-artifact@v2
+ if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+ with:
+ name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless-2022-02-07
+ path: egs/librispeech/ASR/transducer_stateless/exp/
diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml
index f0ebddba3..959e57278 100644
--- a/.github/workflows/run-pretrained-transducer.yml
+++ b/.github/workflows/run-pretrained-transducer.yml
@@ -31,9 +31,6 @@ jobs:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
- torch: ["1.10.0"]
- torchaudio: ["0.10.0"]
- k2-version: ["1.9.dev20211101"]
fail-fast: false
@@ -43,67 +40,37 @@ jobs:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
- python3 -m pip install --upgrade pip pytest
- # numpy 1.20.x does not support python 3.6
- pip install numpy==1.19
- pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
- python3 -m pip install git+https://github.com/lhotse-speech/lhotse
- python3 -m pip install kaldifeat
- # We are in ./icefall and there is a file: requirements.txt in it
- pip install -r requirements.txt
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}
- - name: Install graphviz
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
- python3 -m pip install -qq graphviz
- sudo apt-get -qq install graphviz
+ make -j2 _kaldifeat
- - name: Download pre-trained model
+ - name: Inference with pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
- cd egs/librispeech/ASR
- mkdir tmp
- cd tmp
- git lfs install
- git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
-
- cd ..
- tree tmp
- soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
- ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
-
- - name: Run greedy search decoding
- shell: bash
- run: |
- export PYTHONPATH=$PWD:PYTHONPATH
- cd egs/librispeech/ASR
- ./transducer/pretrained.py \
- --method greedy_search \
- --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
- --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
- ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
-
- - name: Run beam search decoding
- shell: bash
- run: |
export PYTHONPATH=$PWD:$PYTHONPATH
- cd egs/librispeech/ASR
- ./transducer/pretrained.py \
- --method beam_search \
- --beam-size 4 \
- --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
- --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
- ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
- ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+ .github/scripts/run-pre-trained-transducer.sh
diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml
index 98b2e4ebd..ce77c47df 100644
--- a/.github/workflows/run-yesno-recipe.yml
+++ b/.github/workflows/run-yesno-recipe.yml
@@ -33,9 +33,6 @@ jobs:
# TODO: enable macOS for CPU testing
os: [ubuntu-18.04]
python-version: [3.8]
- torch: ["1.10.0"]
- torchaudio: ["0.10.0"]
- k2-version: ["1.9.dev20211101"]
fail-fast: false
steps:
@@ -43,10 +40,17 @@ jobs:
with:
fetch-depth: 0
+ - name: Install graphviz
+ shell: bash
+ run: |
+ sudo apt-get -qq install graphviz
+
- name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
- name: Install libnsdfile and libsox
if: startsWith(matrix.os, 'ubuntu')
@@ -57,13 +61,9 @@ jobs:
- name: Install Python dependencies
run: |
- python3 -m pip install -U pip
- pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
- python3 -m pip install git+https://github.com/lhotse-speech/lhotse
-
- # We are in ./icefall and there is a file: requirements.txt in it
- python3 -m pip install -r requirements.txt
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
- name: Run yesno recipe
shell: bash
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index 2a743705a..239a0280c 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -29,7 +29,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- os: [ubuntu-18.04, macos-10.15]
+ os: [ubuntu-18.04, macos-latest]
python-version: [3.7, 3.9]
fail-fast: false
@@ -45,7 +45,9 @@ jobs:
- name: Install Python dependencies
run: |
- python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2
+ python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
+ # See https://github.com/psf/black/issues/2964
+ # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
- name: Run flake8
shell: bash
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f2c63a3b8..1583926ec 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -33,13 +33,13 @@ jobs:
# disable macOS test for now.
os: [ubuntu-18.04]
python-version: [3.7, 3.8]
- torch: ["1.8.0", "1.10.0"]
- torchaudio: ["0.8.0", "0.10.0"]
- k2-version: ["1.9.dev20211101"]
+ torch: ["1.8.0", "1.11.0"]
+ torchaudio: ["0.8.0", "0.11.0"]
+ k2-version: ["1.15.1.dev20220427"]
exclude:
- torch: "1.8.0"
- torchaudio: "0.10.0"
- - torch: "1.10.0"
+ torchaudio: "0.11.0"
+ - torch: "1.11.0"
torchaudio: "0.8.0"
fail-fast: false
@@ -67,7 +67,7 @@ jobs:
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
+ if [[ ${{ matrix.torchaudio }} == "0.11.0" ]]; then
pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
else
pip install torchaudio==${{ matrix.torchaudio }}
@@ -76,6 +76,9 @@ jobs:
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
pip install git+https://github.com/lhotse-speech/lhotse
# icefall requirements
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf
+
pip install -r requirements.txt
- name: Install graphviz
@@ -103,11 +106,26 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
+ cd ../pruned_transducer_stateless
+ pytest -v -s
+
+ cd ../pruned_transducer_stateless2
+ pytest -v -s
+
+ cd ../pruned_transducer_stateless3
+ pytest -v -s
+
+ cd ../pruned_transducer_stateless4
+ pytest -v -s
+
+ cd ../transducer_stateless
+ pytest -v -s
+
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer
pytest -v -s
- cd ../transducer_stateless
+ cd ../transducer_stateless2
pytest -v -s
cd ../transducer_lstm
@@ -128,11 +146,26 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
+ cd ../pruned_transducer_stateless
+ pytest -v -s
+
+ cd ../pruned_transducer_stateless2
+ pytest -v -s
+
+ cd ../pruned_transducer_stateless3
+ pytest -v -s
+
+ cd ../pruned_transducer_stateless4
+ pytest -v -s
+
+ cd ../transducer_stateless
+ pytest -v -s
+
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
cd ../transducer
pytest -v -s
- cd ../transducer_stateless
+ cd ../transducer_stateless2
pytest -v -s
cd ../transducer_lstm
diff --git a/.gitignore b/.gitignore
index 870d3cea3..1dbf8f395 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,8 @@ exp
exp*/
*.pt
download
+dask-worker-space
+log
*.bak
*-bak
*bak.py
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b59784dbf..446ba0fe7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,6 +4,8 @@ repos:
hooks:
- id: black
args: [--line-length=80]
+ additional_dependencies: ['click==8.0.1']
+ exclude: icefall\/__init__\.py
- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
diff --git a/README.md b/README.md
index f7aed9dc3..fcba0723b 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,18 @@
+## Introduction
+
+icefall contains ASR recipes for various datasets
+using k2 and lhotse.
+
+You can use sherpa to deploy models
+trained with icefall.
+
+You can try pre-trained models from within your browser without the need
+to download or install anything by visiting the Huggingface spaces we provide.
+See the documentation for more details.
+
## Installation
Please refer to
@@ -12,12 +24,19 @@ for installation.
Please refer to
for more information.
-We provide four recipes at present:
+We provide the following recipes:
- [yesno][yesno]
- [LibriSpeech][librispeech]
- [Aishell][aishell]
- [TIMIT][timit]
+ - [TED-LIUM3][tedlium3]
+ - [GigaSpeech][gigaspeech]
+ - [Aidatatang_200zh][aidatatang_200zh]
+ - [WenetSpeech][wenetspeech]
+ - [Alimeeting][alimeeting]
+ - [Aishell4][aishell4]
+ - [TAL_CSASR][tal_csasr]
### yesno
@@ -34,6 +53,9 @@ We do provide a Colab notebook for this recipe.
### LibriSpeech
+Please see
+for the **latest** results.
+
We provide 4 models for this recipe:
- [conformer CTC model][LibriSpeech_conformer_ctc]
@@ -80,16 +102,30 @@ We provide a Colab notebook to run a pre-trained RNN-T conformer model: [](https://colab.research.google.com/drive/1Lm37sNajIpkV4HTzMDF7sn9l0JpfmekN?usp=sharing)
+We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
+
+
+#### k2 pruned RNN-T
+
+| | test-clean | test-other |
+|-----|------------|------------|
+| WER | 2.57 | 5.95 |
+
+#### k2 pruned RNN-T + GigaSpeech
+
+| | test-clean | test-other |
+|-----|------------|------------|
+| WER | 2.00 | 4.63 |
+
### Aishell
@@ -105,7 +141,7 @@ The best CER we currently have is:
| CER | 4.26 |
-We provide a Colab notebook to run a pre-trained conformer CTC model: [](https://colab.research.google.com/drive/1WnG17io5HEZ0Gn_cnh_VzK5QYOoiiklC?usp=sharing)
+We provide a Colab notebook to run a pre-trained conformer CTC model: [
#### Transducer Stateless Model
@@ -113,7 +149,7 @@ The best CER we currently have is:
| | test |
|-----|------|
-| CER | 5.7 |
+| CER | 4.68 |
We provide a Colab notebook to run a pre-trained TransducerStateless model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
@@ -153,6 +189,130 @@ The PER for this model is:
We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [](https://colab.research.google.com/drive/11IT-k4HQIgQngXz1uvWsEYktjqQt7Tmb?usp=sharing)
+### TED-LIUM3
+
+We provide two models for this recipe: [Transducer Stateless: Conformer encoder + Embedding decoder][TED-LIUM3_transducer_stateless] and [Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TED-LIUM3_pruned_transducer_stateless].
+
+#### Transducer Stateless: Conformer encoder + Embedding decoder
+
+The best WER using modified beam search with beam size 4 is:
+
+| | dev | test |
+|-----|-------|--------|
+| WER | 6.91 | 6.33 |
+
+Note: No auxiliary losses are used in the training and no LMs are used in the decoding.
+
+We provide a Colab notebook to run a pre-trained Transducer Stateless model: [](https://colab.research.google.com/drive/1MmY5bBxwvKLNT4A2DJnwiqRXhdchUqPN?usp=sharing)
+
+#### Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+
+The best WER using modified beam search with beam size 4 is:
+
+| | dev | test |
+|-----|-------|--------|
+| WER | 6.77 | 6.14 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing)
+
+### GigaSpeech
+
+We provide two models for this recipe: [Conformer CTC model][GigaSpeech_conformer_ctc]
+and [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
+
+#### Conformer CTC
+
+| | Dev | Test |
+|-----|-------|-------|
+| WER | 10.47 | 10.58 |
+
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+
+| | Dev | Test |
+|----------------------|-------|-------|
+| greedy search | 10.51 | 10.73 |
+| fast beam search | 10.50 | 10.69 |
+| modified beam search | 10.40 | 10.51 |
+
+### Aidatatang_200zh
+
+We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aidatatang_200zh_pruned_transducer_stateless2].
+
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+
+| | Dev | Test |
+|----------------------|-------|-------|
+| greedy search | 5.53 | 6.59 |
+| fast beam search | 5.30 | 6.34 |
+| modified beam search | 5.27 | 6.33 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing)
+
+### WenetSpeech
+
+We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5].
+
+#### Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset, offline ASR)
+
+| | Dev | Test-Net | Test-Meeting |
+|----------------------|-------|----------|--------------|
+| greedy search | 7.80 | 8.75 | 13.49 |
+| fast beam search | 7.94 | 8.74 | 13.80 |
+| modified beam search | 7.76 | 8.71 | 13.41 |
+
+#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
+
+**Streaming**:
+
+| | Dev | Test-Net | Test-Meeting |
+|----------------------|-------|----------|--------------|
+| greedy_search | 8.78 | 10.12 | 16.16 |
+| modified_beam_search | 8.53 | 9.95 | 15.81 |
+| fast_beam_search | 9.01 | 10.47 | 16.28 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless2 model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
+
+### Alimeeting
+
+We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Alimeeting_pruned_transducer_stateless2].
+
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with far subset)
+
+| | Eval | Test-Net |
+|----------------------|--------|----------|
+| greedy search | 31.77 | 34.66 |
+| fast beam search | 31.39 | 33.02 |
+| modified beam search | 30.38 | 34.25 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
+
+### Aishell4
+
+We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
+
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
+
+The best CER(%) results:
+
+| | test |
+|----------------------|--------|
+| greedy search | 29.89 |
+| fast beam search | 28.91 |
+| modified beam search | 29.08 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
+
+### TAL_CSASR
+
+We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5].
+
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+
+The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
+
+| decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
+|--|--|--|--|--|--|--|
+| greedy_search | 7.30 | 6.48 | 19.19 | 7.39 | 6.66 | 19.13 |
+| modified_beam_search | 7.15 | 6.35 | 18.95 | 7.22 | 6.50 | 18.70 |
+| fast_beam_search | 7.18 | 6.39 | 18.90 | 7.27 | 6.55 | 18.77 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
## Deployment with C++
@@ -175,8 +335,25 @@ Please see: [ or your custom
# ones.
extensions = [
+ "sphinx.ext.todo",
"sphinx_rtd_theme",
+ "sphinxcontrib.youtube",
]
# Add any paths that contain templates here, relative to this directory.
@@ -74,3 +76,5 @@ html_context = {
"github_version": "master",
"conf_py_path": "/icefall/docs/source/",
}
+
+todo_include_todos = True
diff --git a/docs/source/huggingface/index.rst b/docs/source/huggingface/index.rst
new file mode 100644
index 000000000..bd731793b
--- /dev/null
+++ b/docs/source/huggingface/index.rst
@@ -0,0 +1,13 @@
+Huggingface
+===========
+
+This section describes how to find pre-trained models.
+It also demonstrates how to try them from within your browser
+without installing anything by using
+`Huggingface spaces `_.
+
+.. toctree::
+ :maxdepth: 2
+
+ pretrained-models
+ spaces
diff --git a/docs/source/huggingface/pic/hugging-face-sherpa-2.png b/docs/source/huggingface/pic/hugging-face-sherpa-2.png
new file mode 100644
index 000000000..3b47bd51b
Binary files /dev/null and b/docs/source/huggingface/pic/hugging-face-sherpa-2.png differ
diff --git a/docs/source/huggingface/pic/hugging-face-sherpa-3.png b/docs/source/huggingface/pic/hugging-face-sherpa-3.png
new file mode 100644
index 000000000..1d7a2d316
Binary files /dev/null and b/docs/source/huggingface/pic/hugging-face-sherpa-3.png differ
diff --git a/docs/source/huggingface/pic/hugging-face-sherpa.png b/docs/source/huggingface/pic/hugging-face-sherpa.png
new file mode 100644
index 000000000..dea0b1d46
Binary files /dev/null and b/docs/source/huggingface/pic/hugging-face-sherpa.png differ
diff --git a/docs/source/huggingface/pretrained-models.rst b/docs/source/huggingface/pretrained-models.rst
new file mode 100644
index 000000000..8ae22f76f
--- /dev/null
+++ b/docs/source/huggingface/pretrained-models.rst
@@ -0,0 +1,17 @@
+Pre-trained models
+==================
+
+We have uploaded pre-trained models for all recipes in ``icefall``
+to ``_.
+
+You can find them by visiting the following link:
+
+``_.
+
+You can also find links to pre-trained models for a specific recipe
+by looking at the corresponding ``RESULTS.md``. For instance:
+
+ - ``_
+ - ``_
+ - ``_
+ - ``_
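+
+For example, if you have cloned ``icefall`` locally, you can list all
+``RESULTS.md`` files with the sketch below (it assumes the standard
+``egs/<dataset>/ASR`` layout of the repository):
+
+.. code-block:: bash
+
+  # Run from the root of the icefall repository.
+  find egs -name "RESULTS.md"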
diff --git a/docs/source/huggingface/spaces.rst b/docs/source/huggingface/spaces.rst
new file mode 100644
index 000000000..e718c3731
--- /dev/null
+++ b/docs/source/huggingface/spaces.rst
@@ -0,0 +1,65 @@
+Huggingface spaces
+==================
+
+We have integrated the server framework
+`sherpa `_
+with `Huggingface spaces `_
+so that you can try pre-trained models from within your browser
+without the need to download or install anything.
+
+All you need is a browser, which you can run on Windows, macOS, Linux, or even on your
+iPad or phone.
+
+Start your browser and visit the following address:
+
+``_
+
+and you will see a page like the following screenshot:
+
+.. image:: ./pic/hugging-face-sherpa.png
+ :alt: screenshot of ``_
+ :target: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
+
+You can:
+
+ 1. Select a language for recognition. Currently, we provide pre-trained models
+ from ``icefall`` for the following languages: ``Chinese``, ``English``, and
+ ``Chinese+English``.
+ 2. After selecting the target language, you can select a pre-trained model
+ corresponding to the language.
+ 3. Select the decoding method. Currently, it provides ``greedy search``
+ and ``modified_beam_search``.
+ 4. If you selected ``modified_beam_search``, you can choose the number of
+ active paths during the search.
+ 5. Either upload a file or record your speech for recognition.
+ 6. Click the button ``Submit for recognition``.
+ 7. Wait for a moment and you will get the recognition results.
+
+The following screenshot shows an example when selecting ``Chinese+English``:
+
+.. image:: ./pic/hugging-face-sherpa-3.png
+ :alt: screenshot of ``_
+ :target: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
+
+
+In the bottom part of the page, you can find a table of examples. You can click
+one of them and then click ``Submit for recognition``.
+
+.. image:: ./pic/hugging-face-sherpa-2.png
+ :alt: screenshot of ``_
+ :target: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
+
+YouTube Video
+-------------
+
+We provide the following YouTube video demonstrating how to use
+``_.
+
+.. note::
+
+ To get the latest news of `next-gen Kaldi `_, please subscribe to
+ the following YouTube channel by `Nadira Povey `_:
+
+ ``_
+
+.. youtube:: ElN3r9dkKE4
diff --git a/docs/source/index.rst b/docs/source/index.rst
index b06047a89..29491e3dc 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -23,3 +23,4 @@ speech recognition recipes using `k2 `_.
installation/index
recipes/index
contributing/index
+ huggingface/index
diff --git a/docs/source/installation/images/README.md b/docs/source/installation/images/README.md
new file mode 100644
index 000000000..97c1e993c
--- /dev/null
+++ b/docs/source/installation/images/README.md
@@ -0,0 +1,4 @@
+
+# Introduction
+
+ is used to generate files in this directory.
diff --git a/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
new file mode 100644
index 000000000..534b2e534
--- /dev/null
+++ b/docs/source/installation/images/k2-gt-v1.9-blueviolet.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/source/installation/images/k2-v1.9-blueviolet.svg b/docs/source/installation/images/k2-v1.9-blueviolet.svg
deleted file mode 100644
index 5a207b370..000000000
--- a/docs/source/installation/images/k2-v1.9-blueviolet.svg
+++ /dev/null
@@ -1 +0,0 @@
-
\ No newline at end of file
diff --git a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg b/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg
deleted file mode 100644
index befc1e19e..000000000
--- a/docs/source/installation/images/python-3.6_3.7_3.8_3.9-blue.svg
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/docs/source/installation/images/python-gt-v3.6-blue.svg b/docs/source/installation/images/python-gt-v3.6-blue.svg
new file mode 100644
index 000000000..4254dc58a
--- /dev/null
+++ b/docs/source/installation/images/python-gt-v3.6-blue.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg b/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
deleted file mode 100644
index 496e5a9ef..000000000
--- a/docs/source/installation/images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/docs/source/installation/images/torch-gt-v1.6.0-green.svg b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
new file mode 100644
index 000000000..d3ece9a17
--- /dev/null
+++ b/docs/source/installation/images/torch-gt-v1.6.0-green.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst
index 0f846c77c..c4474c3d9 100644
--- a/docs/source/installation/index.rst
+++ b/docs/source/installation/index.rst
@@ -15,21 +15,33 @@ Installation
.. |device| image:: ./images/device-CPU_CUDA-orange.svg
:alt: Supported devices
-.. |python_versions| image:: ./images/python-3.6_3.7_3.8_3.9-blue.svg
+.. |python_versions| image:: ./images/python-gt-v3.6-blue.svg
:alt: Supported python versions
-.. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
+.. |torch_versions| image:: ./images/torch-gt-v1.6.0-green.svg
:alt: Supported PyTorch versions
-.. |k2_versions| image:: ./images/k2-v1.9-blueviolet.svg
+.. |k2_versions| image:: ./images/k2-gt-v1.9-blueviolet.svg
:alt: Supported k2 versions
``icefall`` depends on `k2 `_ and
`lhotse `_.
-We recommend you to install ``k2`` first, as ``k2`` is bound to
-a specific version of PyTorch after compilation. Install ``k2`` also
-installs its dependency PyTorch, which can be reused by ``lhotse``.
+We recommend that you use the following steps to install the dependencies.
+
+- (0) Install PyTorch and torchaudio
+- (1) Install k2
+- (2) Install lhotse
+
+.. caution::
+
+ Installation order matters.
+
+(0) Install PyTorch and torchaudio
+----------------------------------
+
+Please refer to ``_ to install PyTorch
+and torchaudio.
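+
+For example, a CPU-only setup can be as simple as the sketch below; the exact
+command (CUDA version, wheel index, etc.) depends on your platform, so please
+follow the instructions on the PyTorch website:
+
+.. code-block:: bash
+
+  pip install torch torchaudio
+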
(1) Install k2
@@ -54,14 +66,15 @@ to install ``k2``.
Please refer to ``_
to install ``lhotse``.
-.. HINT::
- Install ``lhotse`` also installs its dependency `torchaudio `_.
+.. hint::
-.. CAUTION::
+ We strongly recommend that you use::
+
+ pip install git+https://github.com/lhotse-speech/lhotse
+
+ to install the latest version of lhotse.
- If you have installed ``torchaudio``, please consider uninstalling it before
- installing ``lhotse``. Otherwise, it may update your already installed PyTorch.
(3) Download icefall
--------------------
@@ -461,3 +474,19 @@ The decoding log is:
**Congratulations!** You have successfully setup the environment and have run the first recipe in ``icefall``.
Have fun with ``icefall``!
+
+YouTube Video
+-------------
+
+We provide the following YouTube video showing how to install ``icefall``.
+It also shows how to debug various problems that you may encounter while
+using ``icefall``.
+
+.. note::
+
+ To get the latest news of `next-gen Kaldi `_, please subscribe to
+ the following YouTube channel by `Nadira Povey `_:
+
+ ``_
+
+.. youtube:: LVmrBD0tLfE
diff --git a/docs/source/recipes/aishell.rst b/docs/source/recipes/aishell.rst
deleted file mode 100644
index 71ccaa1fc..000000000
--- a/docs/source/recipes/aishell.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-Aishell
-=======
-
-We provide the following models for the Aishell dataset:
-
-.. toctree::
- :maxdepth: 2
-
- aishell/conformer_ctc
- aishell/tdnn_lstm_ctc
diff --git a/docs/source/recipes/aishell/conformer_ctc.rst b/docs/source/recipes/aishell/conformer_ctc.rst
index 2dcf0c728..75a2a8eca 100644
--- a/docs/source/recipes/aishell/conformer_ctc.rst
+++ b/docs/source/recipes/aishell/conformer_ctc.rst
@@ -1,4 +1,4 @@
-Confromer CTC
+Conformer CTC
=============
This tutorial shows you how to run a conformer ctc model
diff --git a/docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png b/docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png
new file mode 100644
index 000000000..6c84b28f2
Binary files /dev/null and b/docs/source/recipes/aishell/images/aishell-transducer_stateless_modified-tensorboard-log.png differ
diff --git a/docs/source/recipes/aishell/index.rst b/docs/source/recipes/aishell/index.rst
new file mode 100644
index 000000000..d072d6e9c
--- /dev/null
+++ b/docs/source/recipes/aishell/index.rst
@@ -0,0 +1,22 @@
+aishell
+=======
+
+Aishell is an open-source Chinese Mandarin speech corpus published by Beijing
+Shell Shell Technology Co.,Ltd.
+
+400 people from different accent areas in China are invited to participate in
+the recording, which is conducted in a quiet indoor environment using a high
+fidelity microphone and downsampled to 16kHz. The manual transcription accuracy
+is above 95%, through professional speech annotation and strict quality
+inspection. The data is free for academic use. We hope to provide a moderate
+amount of data for new researchers in the field of speech recognition.
+
+It can be downloaded from ``_
+
+.. toctree::
+ :maxdepth: 1
+
+ tdnn_lstm_ctc
+ conformer_ctc
+ stateless_transducer
+
diff --git a/docs/source/recipes/aishell/stateless_transducer.rst b/docs/source/recipes/aishell/stateless_transducer.rst
new file mode 100644
index 000000000..e8137b8c1
--- /dev/null
+++ b/docs/source/recipes/aishell/stateless_transducer.rst
@@ -0,0 +1,714 @@
+Stateless Transducer
+====================
+
+This tutorial shows you how to do transducer training in ``icefall``.
+
+.. HINT::
+
+ We use the term transducer here instead of RNN-T or RNN transducer since,
+ as you will see, there are no RNNs in the model.
+
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have setup
+ the environment for ``icefall``.
+
+.. HINT::
+
+ We recommend using a GPU or several GPUs to run this recipe.
+
+In this tutorial, you will learn:
+
+ - (1) What the transducer model looks like
+ - (2) How to prepare data for training and decoding
+ - (3) How to start the training, either with a single GPU or with multiple GPUs
+ - (4) How to do decoding after training, with greedy search, beam search, and **modified beam search**
+ - (5) How to use a pre-trained model provided by us to transcribe sound files
+
+
+The Model
+---------
+
+The transducer model consists of 3 parts:
+
+- **Encoder**: It is a conformer encoder with the following parameters
+
+ - Number of heads: 8
+ - Attention dim: 512
+ - Number of layers: 12
+ - Feedforward dim: 2048
+
+- **Decoder**: We use a stateless model consisting of:
+
+ - An embedding layer with embedding dim 512
+ - A Conv1d layer with a default kernel size 2 (i.e. it sees 2
+ symbols of left-context by default)
+
+- **Joiner**: It consists of a ``nn.tanh()`` and a ``nn.Linear()``.
+
+.. Caution::
+
+ The decoder is stateless and very simple. It is borrowed from
+ ``_
+ (RNN-Transducer with Stateless Prediction Network).
+
+ We make one modification to it: Place a Conv1d layer right after
+ the embedding layer.
+
+With Chinese characters as the modelling unit (the vocabulary size
+is 4336 for this specific dataset),
+the model has ``87939824`` parameters, i.e., about ``88 M``.
+
+The Loss
+--------
+
+We are using ``_
+to compute the transducer loss, which removes extra paddings
+in loss computation to save memory.
+
+.. Hint::
+
+ ``optimized_transducer`` implements the techniques proposed
+ in `Improving RNN Transducer Modeling for End-to-End Speech Recognition `_ to save memory.
+
+ Furthermore, it supports ``modified transducer``, limiting the maximum
+ number of symbols that can be emitted per frame to 1, which simplifies
+ the decoding process significantly. Also, experimental results
+ show that it does not degrade the performance.
+
+ See ``_
+ for what exactly modified transducer is.
+
+ ``_ shows that
+ in the unpruned case ``optimized_transducer`` has an advantage in terms of
+ memory usage.
+
+.. todo::
+
+ Add tutorial about ``pruned_transducer_stateless`` that uses k2
+ pruned transducer loss.
+
+.. hint::
+
+ You can use::
+
+ pip install optimized_transducer
+
+ to install ``optimized_transducer``. Refer to
+ ``_ for other
+ alternatives.
+
+Data Preparation
+----------------
+
+To prepare the data for training, please use the following commands:
+
+.. code-block:: bash
+
+ cd egs/aishell/ASR
+ ./prepare.sh --stop-stage 4
+ ./prepare.sh --stage 6 --stop-stage 6
+
+.. note::
+
+ You can use ``./prepare.sh``, though it will generate FSTs that
+ are not used in transducer training.
+
+When you finish running the script, you will get the following two folders:
+
+ - ``data/fbank``: It saves the pre-computed features
+ - ``data/lang_char``: It contains tokens that will be used in the training
+
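+As a quick sanity check (assuming the default output locations listed above),
+you can inspect the generated folders:
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ ls data/fbank | head
+  $ head -n 5 data/lang_char/tokens.txt
+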
+Training
+--------
+
+.. code-block:: bash
+
+ cd egs/aishell/ASR
+ ./transducer_stateless_modified/train.py --help
+
+shows you the training options that can be passed from the commandline.
+The following options are used quite often:
+
+ - ``--exp-dir``
+
+ The experiment folder to save logs and model checkpoints,
+ defaults to ``./transducer_stateless_modified/exp``.
+
+ - ``--num-epochs``
+
+ It is the number of epochs to train. For instance,
+ ``./transducer_stateless_modified/train.py --num-epochs 30`` trains for 30
+ epochs and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
+ in the folder set by ``--exp-dir``.
+
+ - ``--start-epoch``
+
+ It's used to resume training.
+ ``./transducer_stateless_modified/train.py --start-epoch 10`` loads the
+ checkpoint from ``exp_dir/epoch-9.pt`` and starts
+ training from epoch 10, based on the state from epoch 9.
+
+ - ``--world-size``
+
+ It is used for single-machine multi-GPU DDP training.
+
+ - (a) If it is 1, then no DDP training is used.
+
+ - (b) If it is 2, then GPU 0 and GPU 1 are used for DDP training.
+
+ The following shows some use cases with it.
+
+ **Use case 1**: You have 4 GPUs, but you only want to use GPU 0 and
+ GPU 2 for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/aishell/ASR
+ $ export CUDA_VISIBLE_DEVICES="0,2"
+ $ ./transducer_stateless_modified/train.py --world-size 2
+
+ **Use case 2**: You have 4 GPUs and you want to use all of them
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/aishell/ASR
+ $ ./transducer_stateless_modified/train.py --world-size 4
+
+ **Use case 3**: You have 4 GPUs but you only want to use GPU 3
+ for training. You can do the following:
+
+ .. code-block:: bash
+
+ $ cd egs/aishell/ASR
+ $ export CUDA_VISIBLE_DEVICES="3"
+ $ ./transducer_stateless_modified/train.py --world-size 1
+
+ .. CAUTION::
+
+ Only single-machine multi-GPU DDP training is implemented at present.
+ There is an ongoing PR ``_
+ that adds support for multi-machine multi-GPU DDP training.
+
+ - ``--max-duration``
+
+ It specifies the number of seconds over all utterances in a
+ batch **before padding**.
+ If you encounter CUDA OOM, please reduce it. For instance, if
+ you are using an NVIDIA V100 GPU with 32 GB of memory, we recommend
+ setting it to ``300`` when the vocabulary size is 500.
+
+ .. HINT::
+
+ Due to padding, the number of seconds of all utterances in a
+ batch will usually be larger than ``--max-duration``.
+
+ A larger value for ``--max-duration`` may cause OOM during training,
+ while a smaller value may increase the training time. You have to
+ tune it.
+
+ - ``--lr-factor``
+
+ It controls the learning rate. If you use a single GPU for training, you
+ may want to use a small value for it. If you use multiple GPUs for training,
+ you may increase it.
+
+ - ``--context-size``
+
+ It specifies the kernel size in the decoder. The default value 2 means it
+ functions as a tri-gram LM.
+
+ - ``--modified-transducer-prob``
+
+ It specifies the probability to use modified transducer loss.
+ If it is 0, then no modified transducer is used; if it is 1,
+ then it uses modified transducer loss for all batches. If it is
+ ``p``, it applies modified transducer with probability ``p``.
+
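+A hypothetical invocation combining several of the options above may look like
+the sketch below; the values are purely illustrative and should be tuned for
+your hardware:
+
+.. code-block:: bash
+
+  $ cd egs/aishell/ASR
+  $ export CUDA_VISIBLE_DEVICES="0,1"
+  $ ./transducer_stateless_modified/train.py \
+      --world-size 2 \
+      --num-epochs 30 \
+      --max-duration 250 \
+      --lr-factor 2.0 \
+      --modified-transducer-prob 0.25
+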
+There are some training options, e.g.,
+number of warmup steps,
+that are not passed from the commandline.
+They are pre-configured by the function ``get_params()`` in
+`transducer_stateless_modified/train.py `_
+
+If you need to change them, please modify ``./transducer_stateless_modified/train.py`` directly.
+
+.. CAUTION::
+
+ The training set is speed-perturbed with two factors: 0.9 and 1.1.
+ Each epoch actually processes ``3x150 == 450`` hours of data.
+
+Training logs
+~~~~~~~~~~~~~
+
+Training logs and checkpoints are saved in the folder set by ``--exp-dir``
+(defaults to ``transducer_stateless_modified/exp``). You will find the following files in that directory:
+
+ - ``epoch-0.pt``, ``epoch-1.pt``, ...
+
+ These are checkpoint files, containing model ``state_dict`` and optimizer ``state_dict``.
+ To resume training from some checkpoint, say ``epoch-10.pt``, you can use:
+
+ .. code-block:: bash
+
+ $ ./transducer_stateless_modified/train.py --start-epoch 11
+
+ - ``tensorboard/``
+
+ This folder contains TensorBoard logs. Training loss, validation loss, learning
+ rate, etc, are recorded in these logs. You can visualize them by:
+
+ .. code-block:: bash
+
+ $ cd transducer_stateless_modified/exp/tensorboard
+ $ tensorboard dev upload --logdir . --name "Aishell transducer training with icefall" --description "Training modified transducer, see https://github.com/k2-fsa/icefall/pull/219"
+
+ It will print something like below:
+
+ .. code-block::
+
+ TensorFlow installation not found - running with reduced feature set.
+ Upload started and will continue reading any new data as it's added to the logdir.
+
+ To stop uploading, press Ctrl-C.
+
+ New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/laGZ6HrcQxOigbFD5E0Y3Q/
+
+ [2022-03-03T14:29:45] Started scanning logdir.
+ [2022-03-03T14:29:48] Total uploaded: 8477 scalars, 0 tensors, 0 binary objects
+ Listening for new data in logdir...
+
+ Note there is a `URL `_ in the
+ above output. Click it and you will see the following screenshot:
+
+ .. figure:: images/aishell-transducer_stateless_modified-tensorboard-log.png
+ :width: 600
+ :alt: TensorBoard screenshot
+ :align: center
+ :target: https://tensorboard.dev/experiment/laGZ6HrcQxOigbFD5E0Y3Q
+
+ TensorBoard screenshot.
+
+ - ``log/log-train-xxxx``
+
+ It is the detailed training log in text format, same as the one
+ you saw printed to the console during training.
+
+Usage examples
+~~~~~~~~~~~~~~
+
+The following shows typical use cases:
+
+**Case 1**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+ $ cd egs/aishell/ASR
+ $ ./transducer_stateless_modified/train.py --max-duration 250
+
+It uses ``--max-duration`` of 250 to avoid OOM.
+
+
+**Case 2**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+ $ cd egs/aishell/ASR
+ $ export CUDA_VISIBLE_DEVICES="0,3"
+ $ ./transducer_stateless_modified/train.py --world-size 2
+
+It uses GPU 0 and GPU 3 for DDP training.
+
+**Case 3**
+^^^^^^^^^^
+
+.. code-block:: bash
+
+ $ cd egs/aishell/ASR
+ $ ./transducer_stateless_modified/train.py --num-epochs 10 --start-epoch 3
+
+It loads checkpoint ``./transducer_stateless_modified/exp/epoch-2.pt`` and starts
+training from epoch 3. Also, it trains for 10 epochs.
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+.. code-block:: bash
+
+ $ cd egs/aishell/ASR
+ $ ./transducer_stateless_modified/decode.py --help
+
+shows the options for decoding.
+
+The commonly used options are:
+
+ - ``--method``
+
+ This specifies the decoding method. Currently, it supports:
+
+ - **greedy_search**. You can provide the commandline option ``--max-sym-per-frame``
+ to limit the maximum number of symbols that can be emitted per frame.
+
+ - **beam_search**. You can provide the commandline option ``--beam-size``.
+
+ - **modified_beam_search**. You can also provide the commandline option ``--beam-size``.
+ To use this method, we assume that you have trained your model with modified transducer,
+ i.e., used the option ``--modified-transducer-prob`` in the training.
+
+ The following command uses greedy search for decoding
+
+ .. code-block::
+
+ $ cd egs/aishell/ASR
+ $ ./transducer_stateless_modified/decode.py \
+ --epoch 64 \
+ --avg 33 \
+ --exp-dir ./transducer_stateless_modified/exp \
+ --max-duration 100 \
+ --decoding-method greedy_search \
+ --max-sym-per-frame 1
+
+ The following command uses beam search for decoding
+
+ .. code-block::
+
+ $ cd egs/aishell/ASR
+ $ ./transducer_stateless_modified/decode.py \
+ --epoch 64 \
+ --avg 33 \
+ --exp-dir ./transducer_stateless_modified/exp \
+ --max-duration 100 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+ The following command uses ``modified`` beam search for decoding
+
+ .. code-block::
+
+ $ cd egs/aishell/ASR
+ $ ./transducer_stateless_modified/decode.py \
+ --epoch 64 \
+ --avg 33 \
+ --exp-dir ./transducer_stateless_modified/exp \
+ --max-duration 100 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+ - ``--max-duration``
+
+ It has the same meaning as the one used in training. A larger
+ value may cause OOM.
+
+ - ``--epoch``
+
+ It specifies the epoch whose checkpoint should be used for decoding.
+
+ - ``--avg``
+
+ It specifies the number of models to average. For instance, if it is 3 and if
+ ``--epoch=10``, then it averages the checkpoints ``epoch-8.pt``, ``epoch-9.pt``,
+ and ``epoch-10.pt`` and the averaged checkpoint is used for decoding.
+
+After decoding, you can find the decoding logs and results in `exp_dir/log/`, e.g.,
+``exp_dir/log/greedy_search``.
+
+Pre-trained Model
+-----------------
+
+We have uploaded a pre-trained model to
+``_
+
+In the following, we describe how to use the pre-trained model to transcribe
+a single sound file or multiple sound files.
+
+Install kaldifeat
+~~~~~~~~~~~~~~~~~
+
+`kaldifeat `_ is used to
+extract features for a single sound file or multiple sound files
+at the same time.
+
+Please refer to ``_ for installation.
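+
+In many environments the following is sufficient (only a sketch; see the link
+above for platform-specific instructions, e.g., if you need CUDA support):
+
+.. code-block:: bash
+
+  pip install kaldifeat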
+
+Download the pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The following commands describe how to download the pre-trained model:
+
+.. code-block::
+
+ $ cd egs/aishell/ASR
+ $ mkdir tmp
+ $ cd tmp
+ $ git lfs install
+ $ git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01
+
+
+.. CAUTION::
+
+ You have to use ``git lfs`` to download the pre-trained model.
+
+After downloading, you will have the following files:
+
+.. code-block:: bash
+
+ $ cd egs/aishell/ASR
+ $ tree tmp/icefall-aishell-transducer-stateless-modified-2022-03-01
+
+
+.. code-block:: bash
+
+ tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/
+ |-- README.md
+ |-- data
+ | `-- lang_char
+ | |-- L.pt
+ | |-- lexicon.txt
+ | |-- tokens.txt
+ | `-- words.txt
+ |-- exp
+ | `-- pretrained.pt
+ |-- log
+ | |-- errs-test-beam_4-epoch-64-avg-33-beam-4.txt
+ | |-- errs-test-greedy_search-epoch-64-avg-33-context-2-max-sym-per-frame-1.txt
+ | |-- log-decode-epoch-64-avg-33-beam-4-2022-03-02-12-05-03
+ | |-- log-decode-epoch-64-avg-33-context-2-max-sym-per-frame-1-2022-02-28-18-13-07
+ | |-- recogs-test-beam_4-epoch-64-avg-33-beam-4.txt
+ | `-- recogs-test-greedy_search-epoch-64-avg-33-context-2-max-sym-per-frame-1.txt
+ `-- test_wavs
+ |-- BAC009S0764W0121.wav
+ |-- BAC009S0764W0122.wav
+ |-- BAC009S0764W0123.wav
+ `-- transcript.txt
+
+ 5 directories, 16 files
+
+
+**File descriptions**:
+
+ - ``data/lang_char``
+
+ It contains language-related files. You can find the vocabulary size in ``tokens.txt`` (see the example after this list).
+
+ - ``exp/pretrained.pt``
+
+ It contains pre-trained model parameters, obtained by averaging
+ checkpoints from ``epoch-32.pt`` to ``epoch-64.pt``.
+ Note: We have removed optimizer ``state_dict`` to reduce file size.
+
+ - ``log``
+
+ It contains decoding logs and decoded results.
+
+ - ``test_wavs``
+
+ It contains some test sound files from Aishell ``test`` dataset.
+
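+For instance, to check the vocabulary size mentioned above, you can count the
+lines of ``tokens.txt`` (the line count roughly equals the vocabulary size):
+
+.. code-block:: bash
+
+  $ wc -l tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char/tokens.txt
+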
+Information about the test sound files is listed below:
+
+.. code-block:: bash
+
+ $ soxi tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav
+
+ Input File : 'tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:04.20 = 67263 samples ~ 315.295 CDDA sectors
+ File Size : 135k
+ Bit Rate : 256k
+ Sample Encoding: 16-bit Signed Integer PCM
+
+
+ Input File : 'tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:04.12 = 65840 samples ~ 308.625 CDDA sectors
+ File Size : 132k
+ Bit Rate : 256k
+ Sample Encoding: 16-bit Signed Integer PCM
+
+
+ Input File : 'tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav'
+ Channels : 1
+ Sample Rate : 16000
+ Precision : 16-bit
+ Duration : 00:00:04.00 = 64000 samples ~ 300 CDDA sectors
+ File Size : 128k
+ Bit Rate : 256k
+ Sample Encoding: 16-bit Signed Integer PCM
+
+ Total Duration of 3 files: 00:00:12.32
+
+Usage
+~~~~~
+
+.. code-block::
+
+ $ cd egs/aishell/ASR
+ $ ./transducer_stateless_modified/pretrained.py --help
+
+displays the help information.
+
+It supports three decoding methods:
+
+ - greedy search
+ - beam search
+ - modified beam search
+
+.. note::
+
+ Modified beam search limits the maximum number of symbols that can be
+ emitted per frame to 1. To use this method, you have to ensure that your model
+ has been trained with the option ``--modified-transducer-prob``. Otherwise,
+ it may give you poor results.
+
+Greedy search
+^^^^^^^^^^^^^
+
+The command to run greedy search is given below:
+
+.. code-block:: bash
+
+
+ $ cd egs/aishell/ASR
+ $ ./transducer_stateless_modified/pretrained.py \
+ --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
+ --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
+ --method greedy_search \
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
+
+The output is as follows:
+
+.. code-block::
+
+ 2022-03-03 15:35:26,531 INFO [pretrained.py:239] device: cuda:0
+ 2022-03-03 15:35:26,994 INFO [lexicon.py:176] Loading pre-compiled tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char/Linv.pt
+ 2022-03-03 15:35:27,027 INFO [pretrained.py:246] {'feature_dim': 80, 'encoder_out_dim': 512, 'subsampling_factor': 4, 'attention_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'vgg_frontend': False, 'env_info': {'k2-version': '1.13', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'f4fefe4882bc0ae59af951da3f47335d5495ef71', 'k2-git-date': 'Thu Feb 10 15:16:02 2022', 'lhotse-version': '1.0.0.dev+missing.version.file', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '50d2281-clean', 'icefall-git-date': 'Wed Mar 2 16:02:38 2022', 'icefall-path': '/ceph-fj/fangjun/open-source-2/icefall-aishell', 'k2-path': '/ceph-fj/fangjun/open-source-2/k2-multi-datasets/k2/python/k2/__init__.py', 'lhotse-path': '/ceph-fj/fangjun/open-source-2/lhotse-aishell/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-2-0815224919-75d558775b-mmnv8', 'IP address': '10.177.72.138'}, 'sample_rate': 16000, 'checkpoint': './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt', 'lang_dir': PosixPath('tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char'), 'method': 'greedy_search', 'sound_files': ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav'], 'beam_size': 4, 'context_size': 2, 'max_sym_per_frame': 3, 'blank_id': 0, 'vocab_size': 4336}
+ 2022-03-03 15:35:27,027 INFO [pretrained.py:248] About to create model
+ 2022-03-03 15:35:36,878 INFO [pretrained.py:257] Constructing Fbank computer
+ 2022-03-03 15:35:36,880 INFO [pretrained.py:267] Reading sound files: ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav']
+ 2022-03-03 15:35:36,891 INFO [pretrained.py:273] Decoding started
+ /ceph-fj/fangjun/open-source-2/icefall-aishell/egs/aishell/ASR/transducer_stateless_modified/conformer.py:113: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
+ lengths = ((x_lens - 1) // 2 - 1) // 2
+ 2022-03-03 15:35:37,163 INFO [pretrained.py:320]
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav:
+ 甚 至 出 现 交 易 几 乎 停 滞 的 情 况
+
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav:
+ 一 二 线 城 市 虽 然 也 处 于 调 整 中
+
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav:
+ 但 因 为 聚 集 了 过 多 公 共 资 源
+
+ 2022-03-03 15:35:37,163 INFO [pretrained.py:322] Decoding Done
+
+Beam search
+^^^^^^^^^^^
+
+The command to run beam search is given below:
+
+.. code-block:: bash
+
+
+ $ cd egs/aishell/ASR
+
+ $ ./transducer_stateless_modified/pretrained.py \
+ --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
+ --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
+ --method beam_search \
+ --beam-size 4 \
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
+
+The output is as follows:
+
+.. code-block::
+
+ 2022-03-03 15:39:09,285 INFO [pretrained.py:239] device: cuda:0
+ 2022-03-03 15:39:09,708 INFO [lexicon.py:176] Loading pre-compiled tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char/Linv.pt
+ 2022-03-03 15:39:09,759 INFO [pretrained.py:246] {'feature_dim': 80, 'encoder_out_dim': 512, 'subsampling_factor': 4, 'attention_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'vgg_frontend': False, 'env_info': {'k2-version': '1.13', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'f4fefe4882bc0ae59af951da3f47335d5495ef71', 'k2-git-date': 'Thu Feb 10 15:16:02 2022', 'lhotse-version': '1.0.0.dev+missing.version.file', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '50d2281-clean', 'icefall-git-date': 'Wed Mar 2 16:02:38 2022', 'icefall-path': '/ceph-fj/fangjun/open-source-2/icefall-aishell', 'k2-path': '/ceph-fj/fangjun/open-source-2/k2-multi-datasets/k2/python/k2/__init__.py', 'lhotse-path': '/ceph-fj/fangjun/open-source-2/lhotse-aishell/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-2-0815224919-75d558775b-mmnv8', 'IP address': '10.177.72.138'}, 'sample_rate': 16000, 'checkpoint': './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt', 'lang_dir': PosixPath('tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char'), 'method': 'beam_search', 'sound_files': ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav'], 'beam_size': 4, 'context_size': 2, 'max_sym_per_frame': 3, 'blank_id': 0, 'vocab_size': 4336}
+ 2022-03-03 15:39:09,760 INFO [pretrained.py:248] About to create model
+ 2022-03-03 15:39:18,919 INFO [pretrained.py:257] Constructing Fbank computer
+ 2022-03-03 15:39:18,922 INFO [pretrained.py:267] Reading sound files: ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav']
+ 2022-03-03 15:39:18,929 INFO [pretrained.py:273] Decoding started
+ /ceph-fj/fangjun/open-source-2/icefall-aishell/egs/aishell/ASR/transducer_stateless_modified/conformer.py:113: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
+ lengths = ((x_lens - 1) // 2 - 1) // 2
+ 2022-03-03 15:39:21,046 INFO [pretrained.py:320]
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav:
+ 甚 至 出 现 交 易 几 乎 停 滞 的 情 况
+
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav:
+ 一 二 线 城 市 虽 然 也 处 于 调 整 中
+
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav:
+ 但 因 为 聚 集 了 过 多 公 共 资 源
+
+ 2022-03-03 15:39:21,047 INFO [pretrained.py:322] Decoding Done
+
+Modified Beam search
+^^^^^^^^^^^^^^^^^^^^
+
+The command to run modified beam search is given below:
+
+.. code-block:: bash
+
+
+ $ cd egs/aishell/ASR
+
+ $ ./transducer_stateless_modified/pretrained.py \
+ --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \
+ --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \
+ --method modified_beam_search \
+ --beam-size 4 \
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav
+
+The output is as follows:
+
+.. code-block::
+
+ 2022-03-03 15:41:23,319 INFO [pretrained.py:239] device: cuda:0
+ 2022-03-03 15:41:23,798 INFO [lexicon.py:176] Loading pre-compiled tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char/Linv.pt
+ 2022-03-03 15:41:23,831 INFO [pretrained.py:246] {'feature_dim': 80, 'encoder_out_dim': 512, 'subsampling_factor': 4, 'attention_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 12, 'vgg_frontend': False, 'env_info': {'k2-version': '1.13', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': 'f4fefe4882bc0ae59af951da3f47335d5495ef71', 'k2-git-date': 'Thu Feb 10 15:16:02 2022', 'lhotse-version': '1.0.0.dev+missing.version.file', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '50d2281-clean', 'icefall-git-date': 'Wed Mar 2 16:02:38 2022', 'icefall-path': '/ceph-fj/fangjun/open-source-2/icefall-aishell', 'k2-path': '/ceph-fj/fangjun/open-source-2/k2-multi-datasets/k2/python/k2/__init__.py', 'lhotse-path': '/ceph-fj/fangjun/open-source-2/lhotse-aishell/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-2-0815224919-75d558775b-mmnv8', 'IP address': '10.177.72.138'}, 'sample_rate': 16000, 'checkpoint': './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt', 'lang_dir': PosixPath('tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char'), 'method': 'modified_beam_search', 'sound_files': ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav'], 'beam_size': 4, 'context_size': 2, 'max_sym_per_frame': 3, 'blank_id': 0, 'vocab_size': 4336}
+ 2022-03-03 15:41:23,831 INFO [pretrained.py:248] About to create model
+ 2022-03-03 15:41:32,214 INFO [pretrained.py:257] Constructing Fbank computer
+ 2022-03-03 15:41:32,215 INFO [pretrained.py:267] Reading sound files: ['./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav', './tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav']
+ 2022-03-03 15:41:32,220 INFO [pretrained.py:273] Decoding started
+ /ceph-fj/fangjun/open-source-2/icefall-aishell/egs/aishell/ASR/transducer_stateless_modified/conformer.py:113: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
+ lengths = ((x_lens - 1) // 2 - 1) // 2
+ /ceph-fj/fangjun/open-source-2/icefall-aishell/egs/aishell/ASR/transducer_stateless_modified/beam_search.py:402: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
+ topk_hyp_indexes = topk_indexes // logits.size(-1)
+ 2022-03-03 15:41:32,583 INFO [pretrained.py:320]
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav:
+ 甚 至 出 现 交 易 几 乎 停 滞 的 情 况
+
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav:
+ 一 二 线 城 市 虽 然 也 处 于 调 整 中
+
+ ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav:
+ 但 因 为 聚 集 了 过 多 公 共 资 源
+
+ 2022-03-03 15:41:32,583 INFO [pretrained.py:322] Decoding Done
+
+Colab notebook
+--------------
+
+We provide a colab notebook for this recipe showing how to use a pre-trained model to
+transcribe sound files.
+
+|aishell asr stateless modified transducer colab notebook|
+
+.. |aishell asr stateless modified transducer colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
+ :target: https://colab.research.google.com/drive/12jpTxJB44vzwtcmJl2DTdznW0OawPb9H?usp=sharing
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index 78e9ea569..9d1d83d29 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -10,12 +10,10 @@ We may add recipes for other tasks as well in the future.
.. Other recipes are listed in a alphabetical order.
.. toctree::
- :maxdepth: 3
+ :maxdepth: 2
+ :caption: Table of Contents
- yesno
-
- librispeech
-
- aishell
-
- timit
+ aishell/index
+ librispeech/index
+ timit/index
+ yesno/index
diff --git a/docs/source/recipes/librispeech.rst b/docs/source/recipes/librispeech.rst
deleted file mode 100644
index 946b23407..000000000
--- a/docs/source/recipes/librispeech.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-LibriSpeech
-===========
-
-We provide the following models for the LibriSpeech dataset:
-
-.. toctree::
- :maxdepth: 2
-
- librispeech/tdnn_lstm_ctc
- librispeech/conformer_ctc
diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst
index 5acc4092b..4656acfd6 100644
--- a/docs/source/recipes/librispeech/conformer_ctc.rst
+++ b/docs/source/recipes/librispeech/conformer_ctc.rst
@@ -70,6 +70,17 @@ To run stage 2 to stage 5, use:
All generated files by ``./prepare.sh``, e.g., features, lexicon, etc,
are saved in ``./data`` directory.
+We provide the following YouTube video showing how to run ``./prepare.sh``.
+
+.. note::
+
+ To get the latest news about `next-gen Kaldi `_, please subscribe to
+ the following YouTube channel by `Nadira Povey `_:
+
+ ``_
+
+.. youtube:: ofEIoJL-mGM
+
Training
--------
diff --git a/docs/source/recipes/librispeech/index.rst b/docs/source/recipes/librispeech/index.rst
new file mode 100644
index 000000000..5fa08ab6b
--- /dev/null
+++ b/docs/source/recipes/librispeech/index.rst
@@ -0,0 +1,8 @@
+LibriSpeech
+===========
+
+.. toctree::
+ :maxdepth: 1
+
+ tdnn_lstm_ctc
+ conformer_ctc
diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
index 848026802..ca477fbaa 100644
--- a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
+++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
@@ -45,6 +45,16 @@ To run stage 2 to stage 5, use:
$ ./prepare.sh --stage 2 --stop-stage 5
+We provide the following YouTube video showing how to run ``./prepare.sh``.
+
+.. note::
+
+ To get the latest news about `next-gen Kaldi `_, please subscribe to
+ the following YouTube channel by `Nadira Povey `_:
+
+ ``_
+
+.. youtube:: ofEIoJL-mGM
Training
--------
diff --git a/docs/source/recipes/timit.rst b/docs/source/recipes/timit.rst
deleted file mode 100644
index b630e2ce4..000000000
--- a/docs/source/recipes/timit.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-TIMIT
-===========
-
-We provide the following models for the TIMIT dataset:
-
-.. toctree::
- :maxdepth: 2
-
- timit/tdnn_lstm_ctc
- timit/tdnn_ligru_ctc
\ No newline at end of file
diff --git a/docs/source/recipes/timit/index.rst b/docs/source/recipes/timit/index.rst
new file mode 100644
index 000000000..17f40cdb7
--- /dev/null
+++ b/docs/source/recipes/timit/index.rst
@@ -0,0 +1,9 @@
+TIMIT
+=====
+
+.. toctree::
+ :maxdepth: 1
+
+ tdnn_ligru_ctc
+ tdnn_lstm_ctc
+
diff --git a/docs/source/recipes/timit/tdnn_ligru_ctc.rst b/docs/source/recipes/timit/tdnn_ligru_ctc.rst
index 30877505f..186420ee7 100644
--- a/docs/source/recipes/timit/tdnn_ligru_ctc.rst
+++ b/docs/source/recipes/timit/tdnn_ligru_ctc.rst
@@ -1,5 +1,5 @@
TDNN-LiGRU-CTC
-=============
+==============
This tutorial shows you how to run a TDNN-LiGRU-CTC model with the `TIMIT `_ dataset.
diff --git a/docs/source/recipes/images/yesno-tdnn-tensorboard-log.png b/docs/source/recipes/yesno/images/tdnn-tensorboard-log.png
similarity index 100%
rename from docs/source/recipes/images/yesno-tdnn-tensorboard-log.png
rename to docs/source/recipes/yesno/images/tdnn-tensorboard-log.png
diff --git a/docs/source/recipes/yesno/index.rst b/docs/source/recipes/yesno/index.rst
new file mode 100644
index 000000000..d68523a97
--- /dev/null
+++ b/docs/source/recipes/yesno/index.rst
@@ -0,0 +1,7 @@
+YesNo
+=====
+
+.. toctree::
+ :maxdepth: 1
+
+ tdnn
diff --git a/docs/source/recipes/yesno.rst b/docs/source/recipes/yesno/tdnn.rst
similarity index 99%
rename from docs/source/recipes/yesno.rst
rename to docs/source/recipes/yesno/tdnn.rst
index cb425ad1d..e8b748e6b 100644
--- a/docs/source/recipes/yesno.rst
+++ b/docs/source/recipes/yesno/tdnn.rst
@@ -1,5 +1,5 @@
-yesno
-=====
+TDNN-CTC
+========
This page shows you how to run the `yesno `_ recipe. It contains:
@@ -145,7 +145,7 @@ In ``tdnn/exp``, you will find the following files:
Note there is a URL in the above output, click it and you will see
the following screenshot:
- .. figure:: images/yesno-tdnn-tensorboard-log.png
+ .. figure:: images/tdnn-tensorboard-log.png
:width: 600
:alt: TensorBoard screenshot
:align: center
diff --git a/egs/aidatatang_200zh/ASR/README.md b/egs/aidatatang_200zh/ASR/README.md
new file mode 100644
index 000000000..b85895a09
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/README.md
@@ -0,0 +1,38 @@
+Note: This recipe is trained with the code from this PR: https://github.com/k2-fsa/icefall/pull/375
+# Pre-trained Transducer-Stateless2 models for the Aidatatang_200zh dataset with icefall.
+The model was trained on the full [Aidatatang_200zh](https://www.openslr.org/62) dataset with the scripts in [icefall](https://github.com/k2-fsa/icefall), based on the latest version of k2.
+## Training procedure
+The main repositories are listed below; we will update the training and decoding scripts as new versions are released.
+k2: https://github.com/k2-fsa/k2
+icefall: https://github.com/k2-fsa/icefall
+lhotse: https://github.com/lhotse-speech/lhotse
+* Install k2 and lhotse. The k2 installation guide is at https://k2.readthedocs.io/en/latest/installation/index.html and the lhotse installation guide is at https://lhotse.readthedocs.io/en/latest/getting-started.html#installation. The latest versions should work. Please also install the requirements listed in icefall.
+* Clone icefall (https://github.com/k2-fsa/icefall) and check out the code from the PR mentioned above.
+```
+git clone https://github.com/k2-fsa/icefall
+cd icefall
+```
+* Prepare the data.
+```
+cd egs/aidatatang_200zh/ASR
+bash ./prepare.sh
+```
+* Training
+```
+export CUDA_VISIBLE_DEVICES="0,1"
+./pruned_transducer_stateless2/train.py \
+ --world-size 2 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 250
+```
+## Evaluation results
+The decoding results (WER%) on Aidatatang_200zh (dev and test) are listed below. They were obtained by averaging the models from epoch 11 to epoch 29.
+The WERs are:
+| | dev | test | comment |
+|------------------------------------|------------|------------|------------------------------------------|
+| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 |
+| modified beam search (beam size 4) | 5.27 | 6.33 | --epoch 29, --avg 19, --max-duration 100 |
+| fast beam search (set as default) | 5.30 | 6.34 | --epoch 29, --avg 19, --max-duration 1500|
diff --git a/egs/aidatatang_200zh/ASR/RESULTS.md b/egs/aidatatang_200zh/ASR/RESULTS.md
new file mode 100644
index 000000000..5b82fb61f
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/RESULTS.md
@@ -0,0 +1,72 @@
+## Results
+
+### Aidatatang_200zh Char training results (Pruned Transducer Stateless2)
+
+#### 2022-05-16
+
+Using the code from this PR: https://github.com/k2-fsa/icefall/pull/375.
+
+The WERs are
+
+| | dev | test | comment |
+|------------------------------------|------------|------------|------------------------------------------|
+| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 |
+| modified beam search (beam size 4) | 5.27 | 6.33 | --epoch 29, --avg 19, --max-duration 100 |
+| fast beam search (set as default) | 5.30 | 6.34 | --epoch 29, --avg 19, --max-duration 1500|
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1"
+
+./pruned_transducer_stateless2/train.py \
+ --world-size 2 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 250 \
+ --save-every-n 1000
+
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/xS7kgYf2RwyDpQAOdS8rAA/#scalars
+
+The decoding command is:
+```
+epoch=29
+avg=19
+
+## greedy search
+./pruned_transducer_stateless2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 100
+
+## modified beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 100 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+## fast beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 1500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+A pre-trained model and decoding logs can be found at
diff --git a/egs/aidatatang_200zh/ASR/local/__init__.py b/egs/aidatatang_200zh/ASR/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
new file mode 100755
index 000000000..9850cf251
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the aidatatang_200zh dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
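+
+Usage example (illustrative; --num-mel-bins defaults to 80, and
+./prepare.sh simply runs the script without arguments):
+
+    ./local/compute_fbank_aidatatang_200zh.py --num-mel-bins 80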
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+ num_jobs = min(15, os.cpu_count())
+
+ dataset_parts = (
+ "train",
+ "dev",
+ "test",
+ )
+ prefix = "aidatatang"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+
+ for sup in m["supervisions"]:
+ sup.custom = {"origin": "aidatatang_200zh"}
+
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ if "train" in partition:
+ cut_set = (
+ cut_set
+ + cut_set.perturb_speed(0.9)
+ + cut_set.perturb_speed(1.1)
+ )
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+
+ cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--num-mel-bins",
+ type=int,
+ default=80,
+ help="""The number of mel bins for Fbank""",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_musan.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/local/display_manifest_statistics.py b/egs/aidatatang_200zh/ASR/local/display_manifest_statistics.py
new file mode 100644
index 000000000..d66e5cfca
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,96 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed values to choose the minimum/maximum duration
+to remove short and long utterances during training.
+See the function `remove_short_and_long_utt()`
+in ../../../librispeech/ASR/transducer/train.py
+for usage.
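+
+A minimal sketch of such a filter (illustrative only; the thresholds are
+assumptions you would pick from the statistics shown below):
+
+    def remove_short_and_long_utt(c):
+        # Keep cuts whose duration is between 1 and 20 seconds.
+        return 1.0 <= c.duration <= 20.0
+
+    cuts = cuts.filter(remove_short_and_long_utt)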
+"""
+
+
+from lhotse import load_manifest_lazy
+
+
+def main():
+ paths = [
+ "./data/fbank/aidatatang_cuts_train.jsonl.gz",
+ "./data/fbank/aidatatang_cuts_dev.jsonl.gz",
+ "./data/fbank/aidatatang_cuts_test.jsonl.gz",
+ ]
+
+ for path in paths:
+ print(f"Starting display the statistics for {path}")
+ cuts = load_manifest_lazy(path)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+Starting display the statistics for ./data/fbank/aidatatang_cuts_train.jsonl.gz
+Cuts count: 494715
+Total duration (hours): 422.6
+Speech duration (hours): 422.6 (100.0%)
+***
+Duration statistics (seconds):
+mean 3.1
+std 1.2
+min 1.0
+25% 2.3
+50% 2.7
+75% 3.5
+99% 7.2
+99.5% 8.0
+99.9% 9.5
+max 18.1
+Starting display the statistics for ./data/fbank/aidatatang_cuts_dev.jsonl.gz
+Cuts count: 24216
+Total duration (hours): 20.2
+Speech duration (hours): 20.2 (100.0%)
+***
+Duration statistics (seconds):
+mean 3.0
+std 1.0
+min 1.2
+25% 2.3
+50% 2.7
+75% 3.4
+99% 6.7
+99.5% 7.3
+99.9% 8.8
+max 11.3
+Starting display the statistics for ./data/fbank/aidatatang_cuts_test.jsonl.gz
+Cuts count: 48144
+Total duration (hours): 40.2
+Speech duration (hours): 40.2 (100.0%)
+***
+Duration statistics (seconds):
+mean 3.0
+std 1.1
+min 0.9
+25% 2.3
+50% 2.6
+75% 3.4
+99% 6.9
+99.5% 7.5
+99.9% 9.0
+max 21.8
+"""
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_char.py b/egs/aidatatang_200zh/ASR/local/prepare_char.py
new file mode 100755
index 000000000..d9e47d17a
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/local/prepare_char.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+
+This script takes as input `lang_dir`, which should contain::
+
+ - lang_dir/text,
+ - lang_dir/words.txt
+
+and generates the following files in the directory `lang_dir`:
+
+ - lexicon.txt
+ - lexicon_disambig.txt
+ - L.pt
+ - L_disambig.pt
+ - tokens.txt
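+
+Usage example (illustrative; the script hard-codes lang_dir as
+"data/lang_char", which is how ./prepare.sh uses it):
+
+    ./local/prepare_char.py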
+"""
+
+import re
+from pathlib import Path
+from typing import Dict, List
+
+import k2
+import torch
+from prepare_lang import (
+ Lexicon,
+ add_disambig_symbols,
+ add_self_loops,
+ write_lexicon,
+ write_mapping,
+)
+
+
+def lexicon_to_fst_no_sil(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format).
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ loop_state = 0 # words enter and leave from here
+ next_state = 1 # the next un-allocated state, will be incremented as we go
+
+ arcs = []
+
+ # The blank symbol <blk> is defined in local/train_bpe_model.py
+ assert token2id["<blk>"] == 0
+ assert word2id["<eps>"] == 0
+
+ eps = 0
+
+ for word, pieces in lexicon:
+ assert len(pieces) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ pieces = [
+ token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+ ]
+
+ for i in range(len(pieces) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, pieces[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last piece of this word
+ i = len(pieces) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, pieces[i], w, 0])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
+ """Check if all the given tokens are in token symbol table.
+
+ Args:
+ token_sym_table:
+ Token symbol table that contains all the valid tokens.
+ tokens:
+ A list of tokens.
+ Returns:
+ Return True if there is any token not in the token_sym_table,
+ otherwise False.
+ """
+ for tok in tokens:
+ if tok not in token_sym_table:
+ return True
+ return False
+
+
+def generate_lexicon(
+ token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
+ """Generate a lexicon from a word list and token_sym_table.
+
+ Args:
+ token_sym_table:
+ Token symbol table that mapping token to token ids.
+ words:
+ A list of strings representing words.
+ Returns:
+ Return a dict whose keys are words and values are the corresponding
+ tokens.
+ """
+ lexicon = []
+ for word in words:
+ chars = list(word.strip(" \t"))
+ if contain_oov(token_sym_table, chars):
+ continue
+ lexicon.append((word, chars))
+
+ # The OOV word is <UNK>
+ lexicon.append(("<UNK>", ["<unk>"]))
+ return lexicon
+
+
+def generate_tokens(text_file: str) -> Dict[str, int]:
+ """Generate tokens from the given text file.
+
+ Args:
+ text_file:
+ A file that contains text lines to generate tokens.
+ Returns:
+ Return a dict whose keys are tokens and values are token ids ranged
+ from 0 to len(keys) - 1.
+ """
+ tokens: Dict[str, int] = dict()
+ tokens[""] = 0
+ tokens[""] = 1
+ tokens[""] = 2
+ whitespace = re.compile(r"([ \t\r\n]+)")
+ with open(text_file, "r", encoding="utf-8") as f:
+ for line in f:
+ line = re.sub(whitespace, "", line)
+ chars = list(line)
+ for char in chars:
+ if char not in tokens:
+ tokens[char] = len(tokens)
+ return tokens
+
+
+def main():
+ lang_dir = Path("data/lang_char")
+ text_file = lang_dir / "text"
+
+ word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+ words = word_sym_table.symbols
+
+ excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
+ for w in excluded:
+ if w in words:
+ words.remove(w)
+
+ token_sym_table = generate_tokens(text_file)
+
+ lexicon = generate_lexicon(token_sym_table, words)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ next_token_id = max(token_sym_table.values()) + 1
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in token_sym_table
+ token_sym_table[disambig] = next_token_id
+ next_token_id += 1
+
+ word_sym_table.add("#0")
+ word_sym_table.add("")
+ word_sym_table.add("")
+
+ write_mapping(lang_dir / "tokens.txt", token_sym_table)
+
+ write_lexicon(lang_dir / "lexicon.txt", lexicon)
+ write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst_no_sil(
+ lexicon,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ )
+
+ L_disambig = lexicon_to_fst_no_sil(
+ lexicon_disambig,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), lang_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
new file mode 100755
index 000000000..e5ae89ec4
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py
@@ -0,0 +1,390 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
+consisting of words and tokens (i.e., phones) and does the following:
+
+1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
+
+2. Generate tokens.txt, the token table mapping a token to a unique integer.
+
+3. Generate words.txt, the word table mapping a word to a unique integer.
+
+4. Generate L.pt, in k2 format. It can be loaded by
+
+ d = torch.load("L.pt")
+ lexicon = k2.Fsa.from_dict(d)
+
+5. Generate L_disambig.pt, in k2 format.
+"""
+import argparse
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import k2
+import torch
+
+from icefall.lexicon import read_lexicon, write_lexicon
+
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_tokens(lexicon: Lexicon) -> List[str]:
+ """Get tokens from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique tokens.
+ """
+ ans = set()
+ for _, tokens in lexicon:
+ ans.update(tokens)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def get_words(lexicon: Lexicon) -> List[str]:
+ """Get words from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique words.
+ """
+ ans = set()
+ for word, _ in lexicon:
+ ans.add(word)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
+ """It adds pseudo-token disambiguation symbols #1, #2 and so on
+ at the ends of tokens to ensure that all pronunciations are different,
+ and that none is a prefix of another.
+
+ See also add_lex_disambig.pl from kaldi.
+
+ Args:
+ lexicon:
+ It is returned by :func:`read_lexicon`.
+ Returns:
+ Return a tuple with two elements:
+
+ - The output lexicon with disambiguation symbols
+ - The ID of the max disambiguation symbol that appears
+ in the lexicon
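+
+ Example (illustrative):
+
+ foo f o o -> foo f o o #1
+ fo f o -> fo f o #1
+ food f o o d -> food f o o d #1
+ food2 f o o d -> food2 f o o d #2
+
+ Here "f o" and "f o o" are prefixes of "f o o d", and "f o o d" is
+ shared by two words, so every entry receives a disambiguation symbol.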
+ """
+
+ # (1) Work out the count of each token-sequence in the
+ # lexicon.
+ count = defaultdict(int)
+ for _, tokens in lexicon:
+ count[" ".join(tokens)] += 1
+
+ # (2) For each left sub-sequence of each token-sequence, note down
+ # that it exists (for identifying prefixes of longer strings).
+ issubseq = defaultdict(int)
+ for _, tokens in lexicon:
+ tokens = tokens.copy()
+ tokens.pop()
+ while tokens:
+ issubseq[" ".join(tokens)] = 1
+ tokens.pop()
+
+ # (3) For each entry in the lexicon:
+ # if the token sequence is unique and is not a
+ # prefix of another word, no disambig symbol.
+ # Else output #1, or #2, #3, ... if the same token-seq
+ # has already been assigned a disambig symbol.
+ ans = []
+
+ # We start with #1 since #0 has its own purpose
+ first_allowed_disambig = 1
+ max_disambig = first_allowed_disambig - 1
+ last_used_disambig_symbol_of = defaultdict(int)
+
+ for word, tokens in lexicon:
+ tokenseq = " ".join(tokens)
+ assert tokenseq != ""
+ if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
+ ans.append((word, tokens))
+ continue
+
+ cur_disambig = last_used_disambig_symbol_of[tokenseq]
+ if cur_disambig == 0:
+ cur_disambig = first_allowed_disambig
+ else:
+ cur_disambig += 1
+
+ if cur_disambig > max_disambig:
+ max_disambig = cur_disambig
+ last_used_disambig_symbol_of[tokenseq] = cur_disambig
+ tokenseq += f" #{cur_disambig}"
+ ans.append((word, tokenseq.split()))
+ return ans, max_disambig
+
+
+def generate_id_map(symbols: List[str]) -> Dict[str, int]:
+ """Generate ID maps, i.e., map a symbol to a unique ID.
+
+ Args:
+ symbols:
+ A list of unique symbols.
+ Returns:
+ A dict containing the mapping between symbols and IDs.
+ """
+ return {sym: i for i, sym in enumerate(symbols)}
+
+
+def add_self_loops(
+ arcs: List[List[Any]], disambig_token: int, disambig_word: int
+) -> List[List[Any]]:
+ """Adds self-loops to states of an FST to propagate disambiguation symbols
+ through it. They are added on each state with non-epsilon output symbols
+ on at least one arc out of the state.
+
+ See also fstaddselfloops.pl from Kaldi. One difference is that
+ Kaldi uses OpenFst style FSTs and it has multiple final states.
+ This function uses k2 style FSTs and it does not need to add self-loops
+ to the final state.
+
+ The input label of a self-loop is `disambig_token`, while the output
+ label is `disambig_word`.
+
+ Args:
+ arcs:
+ A list-of-list. The sublist contains
+ `[src_state, dest_state, label, aux_label, score]`
+ disambig_token:
+ It is the token ID of the symbol `#0`.
+ disambig_word:
+ It is the word ID of the symbol `#0`.
+
+ Return:
+ Return new `arcs` containing self-loops.
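+
+ Example (illustrative): for an input arc [0, 1, 3, 5, 0], whose
+ output label 5 is non-epsilon, a self-loop
+ [0, 0, disambig_token, disambig_word, 0] is appended for state 0.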
+ """
+ states_needs_self_loops = set()
+ for arc in arcs:
+ src, dst, ilabel, olabel, score = arc
+ if olabel != 0:
+ states_needs_self_loops.add(src)
+
+ ans = []
+ for s in states_needs_self_loops:
+ ans.append([s, s, disambig_token, disambig_word, 0])
+
+ return arcs + ans
+
+
+def lexicon_to_fst(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ sil_token: str = "SIL",
+ sil_prob: float = 0.5,
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format) with optional silence at
+ the beginning and end of each word.
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ sil_token:
+ The silence token.
+ sil_prob:
+ The probability for adding a silence at the beginning and end
+ of the word.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ assert sil_prob > 0.0 and sil_prob < 1.0
+ # CAUTION: we use score, i.e, negative cost.
+ sil_score = math.log(sil_prob)
+ no_sil_score = math.log(1.0 - sil_prob)
+
+ start_state = 0
+ loop_state = 1 # words enter and leave from here
+ sil_state = 2 # words terminate here when followed by silence; this state
+ # has a silence transition to loop_state.
+ next_state = 3 # the next un-allocated state, will be incremented as we go.
+ arcs = []
+
+ assert token2id["<eps>"] == 0
+ assert word2id["<eps>"] == 0
+
+ eps = 0
+
+ sil_token = token2id[sil_token]
+
+ arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+ arcs.append([start_state, sil_state, eps, eps, sil_score])
+ arcs.append([sil_state, loop_state, sil_token, eps, 0])
+
+ for word, tokens in lexicon:
+ assert len(tokens) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ tokens = [token2id[i] for i in tokens]
+
+ for i in range(len(tokens) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, tokens[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last token of this word
+ # It has two out-going arcs, one to the loop state,
+ # the other one to the sil_state.
+ i = len(tokens) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
+ arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+ )
+ return parser.parse_args()
+
+
+def main():
+ out_dir = Path(get_args().lang_dir)
+ lexicon_filename = out_dir / "lexicon.txt"
+ sil_token = "SIL"
+ sil_prob = 0.5
+
+ lexicon = read_lexicon(lexicon_filename)
+ tokens = get_tokens(lexicon)
+ words = get_words(lexicon)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in tokens
+ tokens.append(f"#{i}")
+
+ assert "" not in tokens
+ tokens = [""] + tokens
+
+ assert "" not in words
+ assert "#0" not in words
+ assert "" not in words
+ assert "" not in words
+
+ words = [""] + words + ["#0", "", ""]
+
+ token2id = generate_id_map(tokens)
+ word2id = generate_id_map(words)
+
+ write_mapping(out_dir / "tokens.txt", token2id)
+ write_mapping(out_dir / "words.txt", word2id)
+ write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst(
+ lexicon,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ )
+
+ L_disambig = lexicon_to_fst(
+ lexicon_disambig,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), out_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
+
+ if False:
+ # Just for debugging, will remove it
+ L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
+ L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
+ L_disambig.labels_sym = L.labels_sym
+ L_disambig.aux_labels_sym = L.aux_labels_sym
+ L.draw(out_dir / "L.png", title="L")
+ L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aidatatang_200zh/ASR/local/prepare_words.py b/egs/aidatatang_200zh/ASR/local/prepare_words.py
new file mode 100755
index 000000000..65aca2983
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/local/prepare_words.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input words.txt without ids:
+ - words_no_ids.txt
+and generates the new words.txt with the corresponding ids.
+ - words.txt
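+
+Usage example (illustrative; these values are also the defaults and match
+how ./prepare.sh invokes the script):
+
+    ./local/prepare_words.py \
+        --input-file data/lang_char/words_no_ids.txt \
+        --output-file data/lang_char/words.txt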
+"""
+
+
+import argparse
+import logging
+
+from tqdm import tqdm
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Prepare words.txt",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--input-file",
+ default="data/lang_char/words_no_ids.txt",
+ type=str,
+ help="the words file without ids for WenetSpeech",
+ )
+ parser.add_argument(
+ "--output-file",
+ default="data/lang_char/words.txt",
+ type=str,
+ help="the words file with ids for WenetSpeech",
+ )
+
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ input_file = args.input_file
+ output_file = args.output_file
+
+ f = open(input_file, "r", encoding="utf-8")
+ lines = f.readlines()
+ new_lines = []
+ add_words = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]
+ new_lines.extend(add_words)
+
+ logging.info("Starting reading the input file")
+ for i in tqdm(range(len(lines))):
+ x = lines[i]
+ idx = 4 + i
+ new_line = str(x.strip("\n")) + " " + str(idx)
+ new_lines.append(new_line)
+
+ logging.info("Starting writing the words.txt")
+ f_out = open(output_file, "w", encoding="utf-8")
+ for line in new_lines:
+ f_out.write(line)
+ f_out.write("\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
new file mode 100755
index 000000000..d4cf62bba
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/local/test_prepare_lang.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+
+import os
+import tempfile
+
+import k2
+from prepare_lang import (
+ add_disambig_symbols,
+ generate_id_map,
+ get_tokens,
+ get_words,
+ lexicon_to_fst,
+ read_lexicon,
+ write_lexicon,
+ write_mapping,
+)
+
+
+def generate_lexicon_file() -> str:
+ fd, filename = tempfile.mkstemp()
+ os.close(fd)
+ s = """
+ !SIL SIL
+ <SPOKEN_NOISE> SPN
+ <UNK> SPN
+ f f
+ a a
+ foo f o o
+ bar b a r
+ bark b a r k
+ food f o o d
+ food2 f o o d
+ fo f o
+ """.strip()
+ with open(filename, "w") as f:
+ f.write(s)
+ return filename
+
+
+def test_read_lexicon(filename: str):
+ lexicon = read_lexicon(filename)
+ phones = get_tokens(lexicon)
+ words = get_words(lexicon)
+ print(lexicon)
+ print(phones)
+ print(words)
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+ print(lexicon_disambig)
+ print("max disambig:", f"#{max_disambig}")
+
+ phones = ["", "SIL", "SPN"] + phones
+ for i in range(max_disambig + 1):
+ phones.append(f"#{i}")
+ words = [""] + words
+
+ phone2id = generate_id_map(phones)
+ word2id = generate_id_map(words)
+
+ print(phone2id)
+ print(word2id)
+
+ write_mapping("phones.txt", phone2id)
+ write_mapping("words.txt", word2id)
+
+ write_lexicon("a.txt", lexicon)
+ write_lexicon("a_disambig.txt", lexicon_disambig)
+
+ fsa = lexicon_to_fst(lexicon, token2id=phone2id, word2id=word2id)
+ fsa.labels_sym = k2.SymbolTable.from_file("phones.txt")
+ fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
+ fsa.draw("L.pdf", title="L")
+
+ fsa_disambig = lexicon_to_fst(
+ lexicon_disambig, token2id=phone2id, word2id=word2id
+ )
+ fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
+ fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
+ fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
+
+
+def main():
+ filename = generate_lexicon_file()
+ test_read_lexicon(filename)
+ os.remove(filename)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py
new file mode 100755
index 000000000..71be2a613
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/local/text2token.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env python3
+# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe)
+# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import codecs
+import re
+import sys
+from typing import List
+
+from pypinyin import lazy_pinyin, pinyin
+
+is_python2 = sys.version_info[0] == 2
+
+
+def exist_or_not(i, match_pos):
+ start_pos = None
+ end_pos = None
+ for pos in match_pos:
+ if pos[0] <= i < pos[1]:
+ start_pos = pos[0]
+ end_pos = pos[1]
+ break
+
+ return start_pos, end_pos
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="convert raw text to tokenized text",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--nchar",
+ "-n",
+ default=1,
+ type=int,
+ help="number of characters to split, i.e., \
+ aabb -> a a b b with -n 1 and aa bb with -n 2",
+ )
+ parser.add_argument(
+ "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
+ )
+ parser.add_argument(
+ "--space", default="", type=str, help="space symbol"
+ )
+ parser.add_argument(
+ "--non-lang-syms",
+ "-l",
+ default=None,
+ type=str,
+ help="list of non-linguistic symobles, e.g., etc.",
+ )
+ parser.add_argument(
+ "text", type=str, default=False, nargs="?", help="input text"
+ )
+ parser.add_argument(
+ "--trans_type",
+ "-t",
+ type=str,
+ default="char",
+ choices=["char", "pinyin", "lazy_pinyin"],
+ help="""Transcript type. char/pinyin/lazy_pinyin""",
+ )
+ return parser
+
+
+def token2id(
+ texts, token_table, token_type: str = "lazy_pinyin", oov: str = "<unk>"
+) -> List[List[int]]:
+ """Convert token to id.
+ Args:
+ texts:
+ The input texts, it refers to the chinese text here.
+ token_table:
+ The token table is built based on "data/lang_xxx/token.txt"
+ token_type:
+ The type of token, such as "pinyin" and "lazy_pinyin".
+ oov:
+ Out of vocabulary token. When a word(token) in the transcript
+ does not exist in the token list, it is replaced with `oov`.
+
+ Returns:
+ The list of ids for the input texts.
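+
+ Example (illustrative): with token_type="lazy_pinyin", the text
+ "你好" is split into characters and converted by lazy_pinyin() to
+ ["ni", "hao"]; each syllable is then looked up in token_table, and
+ any missing syllable is mapped to the id of `oov`.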
+ """
+ if texts is None:
+ raise ValueError("texts can't be None!")
+ else:
+ oov_id = token_table[oov]
+ ids: List[List[int]] = []
+ for text in texts:
+ chars_list = list(str(text))
+ if token_type == "lazy_pinyin":
+ text = lazy_pinyin(chars_list)
+ sub_ids = [
+ token_table[txt] if txt in token_table else oov_id
+ for txt in text
+ ]
+ ids.append(sub_ids)
+ else: # token_type = "pinyin"
+ text = pinyin(chars_list)
+ sub_ids = [
+ token_table[txt[0]] if txt[0] in token_table else oov_id
+ for txt in text
+ ]
+ ids.append(sub_ids)
+ return ids
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ rs = []
+ if args.non_lang_syms is not None:
+ with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
+ nls = [x.rstrip() for x in f.readlines()]
+ rs = [re.compile(re.escape(x)) for x in nls]
+
+ if args.text:
+ f = codecs.open(args.text, encoding="utf-8")
+ else:
+ f = codecs.getreader("utf-8")(
+ sys.stdin if is_python2 else sys.stdin.buffer
+ )
+
+ sys.stdout = codecs.getwriter("utf-8")(
+ sys.stdout if is_python2 else sys.stdout.buffer
+ )
+ line = f.readline()
+ n = args.nchar
+ while line:
+ x = line.split()
+ print(" ".join(x[: args.skip_ncols]), end=" ")
+ a = " ".join(x[args.skip_ncols :]) # noqa E203
+
+ # get all matched positions
+ match_pos = []
+ for r in rs:
+ i = 0
+ while i >= 0:
+ m = r.search(a, i)
+ if m:
+ match_pos.append([m.start(), m.end()])
+ i = m.end()
+ else:
+ break
+ if len(match_pos) > 0:
+ chars = []
+ i = 0
+ while i < len(a):
+ start_pos, end_pos = exist_or_not(i, match_pos)
+ if start_pos is not None:
+ chars.append(a[start_pos:end_pos])
+ i = end_pos
+ else:
+ chars.append(a[i])
+ i += 1
+ a = chars
+
+ if args.trans_type == "pinyin":
+ a = pinyin(list(str(a)))
+ a = [one[0] for one in a]
+
+ if args.trans_type == "lazy_pinyin":
+ a = lazy_pinyin(list(str(a)))
+
+ a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203
+
+ a_flat = []
+ for z in a:
+ a_flat.append("".join(z))
+
+ a_chars = "".join(a_flat)
+ print(a_chars)
+ line = f.readline()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh
new file mode 100755
index 000000000..3da783006
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/prepare.sh
@@ -0,0 +1,118 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/aidatatang_200zh
+# You can find "corpus" and "transcript" inside it.
+# You can download it at
+# https://openslr.org/62/
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ if [ ! -f $dl_dir/aidatatang_200zh/transcript/aidatatang_200_zh_transcript.txt ]; then
+ lhotse download aidatatang-200zh $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare aidatatang_200zh manifest"
+ # We assume that you have downloaded the aidatatang_200zh corpus
+ # to $dl_dir/aidatatang_200zh
+ if [ ! -f data/manifests/aidatatang_200zh/.manifests.done ]; then
+ mkdir -p data/manifests/aidatatang_200zh
+ lhotse prepare aidatatang-200zh $dl_dir data/manifests/aidatatang_200zh
+ touch data/manifests/aidatatang_200zh/.manifests.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Process aidatatang_200zh"
+ if [ ! -f data/fbank/aidatatang_200zh/.fbank.done ]; then
+ mkdir -p data/fbank/aidatatang_200zh
+ lhotse prepare aidatatang-200zh $dl_dir data/manifests/aidatatang_200zh
+ touch data/fbank/aidatatang_200zh/.fbank.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to data/musan
+ if [ ! -f data/manifests/.musan_manifests.done ]; then
+ log "It may take 6 minutes"
+ mkdir -p data/manifests
+ lhotse prepare musan $dl_dir/musan data/manifests
+ touch data/manifests/.musan_manifests.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank for musan"
+ if [ ! -f data/fbank/.musan.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_musan.py
+ touch data/fbank/.musan.done
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Compute fbank for aidatatang_200zh"
+ if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_aidatatang_200zh.py
+ touch data/fbank/.aidatatang_200zh.done
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Prepare char based lang"
+ lang_char_dir=data/lang_char
+ mkdir -p $lang_char_dir
+
+ # Prepare text.
+ grep "\"text\":" data/manifests/aidatatang_200zh/supervisions_train.json \
+ | sed -e 's/["text:\t ]*//g' | sed 's/,//g' \
+ | ./local/text2token.py -t "char" > $lang_char_dir/text
+
+ # Prepare words.txt
+ grep "\"text\":" data/manifests/aidatatang_200zh/supervisions_train.json \
+ | sed -e 's/["text:\t]*//g' | sed 's/,//g' \
+ | ./local/text2token.py -t "char" > $lang_char_dir/text_words
+
+ cat $lang_char_dir/text_words | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
+ | uniq > $lang_char_dir/words_no_ids.txt
+
+ if [ ! -f $lang_char_dir/words.txt ]; then
+ ./local/prepare_words.py \
+ --input-file $lang_char_dir/words_no_ids.txt \
+ --output-file $lang_char_dir/words.txt
+ fi
+
+ if [ ! -f $lang_char_dir/L_disambig.pt ]; then
+ ./local/prepare_char.py
+ fi
+fi
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/__init__.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
new file mode 100644
index 000000000..6a5b57e24
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -0,0 +1,420 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ load_manifest,
+ load_manifest_lazy,
+ set_caching_enabled,
+)
+from lhotse.dataset import (
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SingleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+set_caching_enabled(False)
+torch.set_num_threads(1)
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class Aidatatang_200zhAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+ This class should be derived for specific corpora used in ASR tasks.
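+
+ A minimal usage sketch (illustrative; `train_cuts` is assumed to be a
+ lhotse CutSet loaded from data/fbank):
+
+ parser = argparse.ArgumentParser()
+ Aidatatang_200zhAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ datamodule = Aidatatang_200zhAsrDataModule(args)
+ train_dl = datamodule.train_dataloaders(train_cuts)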
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/dev/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=300,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(
+ self.args.manifest_dir / "musan_cuts.jsonl.gz"
+ )
+
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ transforms.append(
+ CutMix(
+ cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+ )
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(
+ f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+ )
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
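+ # Pick a sampler: DynamicBucketingSampler groups cuts of similar
+ # duration into buckets to reduce padding, while SingleCutSampler
+ # simply fills each batch up to --max-duration seconds of audio.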
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=True,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SingleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_dl.sampler.load_state_dict(sampler_state_dict)
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+
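+ # The dev cuts may come from WebDataset shards (see decode.py), which
+ # can only be iterated sequentially, so the sampler and dataset are
+ # wrapped together into a single iterable-style dataset.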
+ from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+
+ dev_iter_dataset = IterableDatasetWrapper(
+ dataset=validate,
+ sampler=valid_sampler,
+ )
+ valid_dl = DataLoader(
+ dev_iter_dataset,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+
+ test_iter_dataset = IterableDatasetWrapper(
+ dataset=test,
+ sampler=sampler,
+ )
+ test_dl = DataLoader(
+ test_iter_dataset,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aidatatang_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aidatatang_cuts_dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aidatatang_cuts_test.jsonl.gz"
+ )
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/beam_search.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/conformer.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/conformer.py
new file mode 120000
index 000000000..a65957180
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
new file mode 100755
index 000000000..b78c600c3
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
@@ -0,0 +1,600 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 6 \
+ --avg 3 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 100 \
+ --decoding-method greedy_search
+
+(2) modified beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 6 \
+ --avg 3 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 100 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(3) fast beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 6 \
+ --avg 3 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 1500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import Aidatatang_200zhAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+
+ parser.add_argument(
+ "--batch",
+ type=int,
+ default=None,
+ help="It specifies the batch checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--avg-last-n",
+ type=int,
+ default=0,
+ help="""If positive, --epoch and --avg are ignored and it
+ will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+ where xxx is the number of processed batches while
+ saving that checkpoint.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An interger indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search".
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7".
+ - value: It contains the decoding result. `len(value)` equals the
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. It can be either a `k2.trivial_graph` or HLG.
+ It is used only when --decoding-method is fast_beam_search.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=feature, x_lens=feature_lens
+ )
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ decoding_graph:
+ The decoding graph. It can be either a `k2.trivial_graph` or HLG.
+ It is used only when --decoding-method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 100
+ else:
+ log_interval = 50
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ texts = [list(str(text).replace(" ", "")) for text in texts]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ this_batch.append((ref_text, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir
+ / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ Aidatatang_200zhAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
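+ # Select the weights to decode with: the last n batch checkpoints,
+ # a single batch checkpoint, a single epoch, or an average over the
+ # last --avg epoch checkpoints.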
+ if params.avg_last_n > 0:
+ filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ elif params.batch is not None:
+ filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt"
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints([filenames], device=device))
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
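+ # fast_beam_search needs a decoding graph; without an external LM we
+ # use the trivial graph, which accepts any token sequence.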
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # Note: Please use "pip install webdataset==0.1.103"
+ # for installing the webdataset.
+ import glob
+ import os
+
+ from lhotse import CutSet
+ from lhotse.dataset.webdataset import export_to_webdataset
+
+ aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
+
+ dev = "dev"
+ test = "test"
+
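+ # Export the dev and test cuts to WebDataset tar shards the first time
+ # this script runs; later runs reuse the existing shards.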
+ if not os.path.exists(f"{dev}/shared-0.tar"):
+ os.makedirs(dev)
+ dev_cuts = aidatatang_200zh.valid_cuts()
+ export_to_webdataset(
+ dev_cuts,
+ output_path=f"{dev}/shared-%d.tar",
+ shard_size=300,
+ )
+
+ if not os.path.exists(f"{test}/shared-0.tar"):
+ os.makedirs(test)
+ test_cuts = aidatatang_200zh.test_cuts()
+ export_to_webdataset(
+ test_cuts,
+ output_path=f"{test}/shared-%d.tar",
+ shard_size=300,
+ )
+
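+ # Load the exported shards back as lazy CutSets. Shards are split
+ # across dataloader workers (and nodes) so each one reads a distinct
+ # subset.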
+ dev_shards = [
+ str(path)
+ for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+ ]
+ cuts_dev_webdataset = CutSet.from_webdataset(
+ dev_shards,
+ split_by_worker=True,
+ split_by_node=True,
+ shuffle_shards=True,
+ )
+
+ test_shards = [
+ str(path)
+ for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
+ ]
+ cuts_test_webdataset = CutSet.from_webdataset(
+ test_shards,
+ split_by_worker=True,
+ split_by_node=True,
+ shuffle_shards=True,
+ )
+
+ dev_dl = aidatatang_200zh.valid_dataloaders(cuts_dev_webdataset)
+ test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)
+
+ test_sets = ["dev", "test"]
+ test_dls = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ )
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decoder.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decoder.py
new file mode 120000
index 000000000..722e1c894
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/encoder_interface.py
new file mode 120000
index 000000000..653c5b09a
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
new file mode 100644
index 000000000..00b54c39f
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -0,0 +1,181 @@
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless2/export.py \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --epoch 29 \
+ --avg 19
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless2/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/aidatatang_200zh/ASR
+ ./pruned_transducer_stateless2/decode.py \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 100 \
+ --lang-dir data/lang_char
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.lexicon import Lexicon
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ return parser
+
+
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ model.to(device)
+
+ if params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit:
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+ # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torch.jit.script")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/joiner.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/joiner.py
new file mode 120000
index 000000000..9052f3cbb
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/model.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/model.py
new file mode 120000
index 000000000..a99e74334
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/model.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/optim.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
new file mode 100644
index 000000000..eb5e6b0d4
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
@@ -0,0 +1,347 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless2/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+ --lang-dir ./data/lang_char \
+ --decoding-method greedy_search \
+ --max-sym-per-frame 1 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./pruned_transducer_stateless2/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+ --lang-dir ./data/lang_char \
+ --decoding-method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./pruned_transducer_stateless2/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+ --lang-dir ./data/lang_char \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by
+./pruned_transducer_stateless2/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
+
+from icefall.lexicon import Lexicon
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Path to lang.
+ """,
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="Used only when --method is beam_search and modified_beam_search ",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame. Used only when
+ --decoding-method is greedy_search.
+ """,
+ )
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert sample_rate == expected_sample_rate, (
+ f"expected sample rate: {expected_sample_rate}. "
+ f"Given: {sample_rate}"
+ )
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+
+ params.update(vars(args))
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info("Creating model")
+ model = get_transducer_model(params)
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
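+ # Pad all features in the batch to the same length; log(1e-10) is used
+ # as the padding value so that padded frames look like near silence.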
+ features = pad_sequence(
+ features, batch_first=True, padding_value=math.log(1e-10)
+ )
+
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ with torch.no_grad():
+ encoder_out, encoder_out_lens = model.encoder(
+ x=features, x_lens=feature_lengths
+ )
+
+ hyps = []
+ msg = f"Using {params.decoding_method}"
+ logging.info(msg)
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
new file mode 100644
index 000000000..d46838b68
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
@@ -0,0 +1,972 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1"
+
+./pruned_transducer_stateless2/train.py \
+ --world-size 2 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 250 \
+ --save-every-n 1000
+
+# For mixed precision training:
+
+./pruned_transducer_stateless2/train.py \
+ --world-size 2 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 250 \
+ --save-every-n 1000 \
+ --use-fp16 True
+
+"""
+
+import argparse
+import logging
+import os
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import Aidatatang_200zhAsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[
+ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
+
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12359,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--initial-lr",
+ type=float,
+ default=0.003,
+ help="The initial learning rate. This value should not need to be changed.",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate decreases.
+ We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)"
+ "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=8000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=20,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+ Explanation of options saved in `params`:
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+ - best_train_epoch: It is the epoch that has the best training loss.
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+ - subsampling_factor: The subsampling factor for the model.
+ - encoder_dim: Hidden dim for multi-head attention model.
+ - decoder_dim: Dimension of the stateless decoder (prediction network).
+ - model_warm_step: Number of batches after which the model warmup
+ factor reaches 1.0 (see compute_loss).
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 10,
+ "log_interval": 1,
+ "reset_interval": 200,
+ "valid_interval": 400,
+ # parameters for conformer
+ "feature_dim": 80,
+ "subsampling_factor": 4,
+ "encoder_dim": 512,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "num_encoder_layers": 12,
+ # parameters for decoder
+ "decoder_dim": 512,
+ # parameters for joiner
+ "joiner_dim": 512,
+ # parameters for Noam
+ "model_warm_step": 200,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Conformer and Transformer
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.encoder_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`.
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 0:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+ warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute CTC loss given the model and its inputs.
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ texts = batch["supervisions"]["text"]
+
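+ # texts_to_ids() may return either a plain list of token-id lists or a
+ # k2.RaggedTensor, depending on the graph compiler; normalize it to a
+ # RaggedTensor on the training device.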
+ y = graph_compiler.texts_to_ids(texts)
+ if type(y) == list:
+ y = k2.RaggedTensor(y).to(device)
+ else:
+ y = y.to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ warmup=warmup,
+ )
+ # after the main warmup step, we keep pruned_loss_scale small
+ # for the same amount of time (model_warm_step), to avoid
+ # overwhelming the simple_loss and causing it to diverge,
+ # in case it had not fully learned the alignment yet.
+ pruned_loss_scale = (
+ 0.0
+ if warmup < 1.0
+ else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+ )
+ loss = (
+ params.simple_loss_scale * simple_loss
+ + pruned_loss_scale * pruned_loss
+ )
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (
+ (feature_lens // params.subsampling_factor).sum().item()
+ )
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of processes (GPUs) in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
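+ # Run the forward pass under autocast when --use-fp16 is given; the
+ # GradScaler below compensates for the reduced fp16 dynamic range.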
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ warmup=(params.batch_idx_train / params.model_warm_step),
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}"
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank])
+ model.device = device
+
+ optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2 ** 22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
+
+ train_cuts = aidatatang_200zh.train_cuts()
+ valid_cuts = aidatatang_200zh.valid_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 10.0 seconds
+ #
+ # Caution: There is a reason to select 10.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold.
+ return 1.0 <= c.duration <= 10.0
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ valid_dl = aidatatang_200zh.valid_dataloaders(valid_cuts)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = aidatatang_200zh.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ if not params.print_diagnostics and params.start_batch == 0:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs):
+ scheduler.step_epoch(epoch)
+ fix_random_seed(params.seed + epoch)
+ train_dl.sampler.set_epoch(epoch)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+ model: nn.Module,
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+ # (i.e. are not remembered by the decaying-average in adam), because
+ # we want to avoid these params being subject to shrinkage in adam.
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ warmup=0.0,
+ )
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ Aidatatang_200zhAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.lang_dir = Path(args.lang_dir)
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aidatatang_200zh/ASR/shared b/egs/aidatatang_200zh/ASR/shared
new file mode 120000
index 000000000..3a3b28f96
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/shared
@@ -0,0 +1 @@
+../../../egs/aishell/ASR/shared
\ No newline at end of file
diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md
index 3fd177376..75fc6326e 100644
--- a/egs/aishell/ASR/README.md
+++ b/egs/aishell/ASR/README.md
@@ -1,3 +1,23 @@
-Please refer to
+# Introduction
+
+Please refer to
for how to run models in this recipe.
+
+
+
+# Transducers
+
+There are various folders containing the name `transducer` in this folder.
+The following table lists the differences among them.
+
+| | Encoder | Decoder | Comment |
+|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------|
+| `transducer_stateless` | Conformer | Embedding + Conv1d | with `k2.rnnt_loss` |
+| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
+| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
+| `pruned_transducer_stateless3`     | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_200zh as extra data|
+
+The decoder in `transducer_stateless` is modified from the paper
+[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer.
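+
+As a rough illustration only (the class name, defaults, and exact layer
+arguments below are assumptions, not the code used in these recipes), such a
+stateless decoder can be sketched as:
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class StatelessDecoder(nn.Module):
+    """Embedding + Conv1d prediction network (no recurrence)."""
+
+    def __init__(
+        self, vocab_size: int, dim: int, context_size: int, blank_id: int = 0
+    ):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, dim, padding_idx=blank_id)
+        # A depthwise Conv1d over the last `context_size` labels replaces the
+        # usual LSTM, which is what makes the prediction network stateless.
+        self.conv = nn.Conv1d(
+            dim, dim, kernel_size=context_size, groups=dim, bias=False
+        )
+        self.context_size = context_size
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        # y: (N, U), the previously emitted labels
+        emb = self.embedding(y).permute(0, 2, 1)  # (N, dim, U)
+        emb = F.pad(emb, pad=(self.context_size - 1, 0))  # causal left padding
+        out = self.conv(emb).permute(0, 2, 1)  # (N, U, dim)
+        return F.relu(out)
+```
+
+Because the decoder only looks at the last `context_size` labels, its output
+for a given history can be computed without carrying any recurrent state.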
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index dd27e1f35..b420a1982 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -1,12 +1,281 @@
## Results
+### Aishell training results (Stateless Transducer)
+
+#### Pruned transducer stateless 3
+
+See
+
+
+[./pruned_transducer_stateless3](./pruned_transducer_stateless3)
+
+It uses pruned RNN-T.
+
+| | test | dev | comment |
+|------------------------|------|------|---------------------------------------|
+| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 |
+| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 |
+| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 |
+
+Training command is:
+
+```bash
+./prepare.sh
+./prepare_aidatatang_200zh.sh
+
+export CUDA_VISIBLE_DEVICES="4,5,6,7"
+
+./pruned_transducer_stateless3/train.py \
+ --exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \
+ --world-size 4 \
+ --max-duration 200 \
+ --datatang-prob 0.5 \
+ --start-epoch 1 \
+ --num-epochs 30 \
+ --use-fp16 1 \
+ --num-encoder-layers 12 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --encoder-dim 512 \
+ --context-size 1 \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --master-port 12356
+```
+
+**Caution**: It uses `--context-size=1`.
+
+The tensorboard log is available at
+
+
+The decoding command is:
+
+```bash
+for epoch in 29; do
+ for avg in 5; do
+ for m in greedy_search modified_beam_search fast_beam_search; do
+ ./pruned_transducer_stateless3/decode.py \
+ --exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \
+ --epoch $epoch \
+ --avg $avg \
+ --use-averaged-model 1 \
+ --max-duration 600 \
+ --decoding-method $m \
+ --num-encoder-layers 12 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --context-size 1 \
+ --encoder-dim 512 \
+ --decoder-dim 512 \
+ --joiner-dim 512
+ done
+ done
+done
+```
+
+Pretrained models, training logs, decoding logs, and decoding results
+are available at
+
+
+We have a tutorial in [sherpa](https://github.com/k2-fsa/sherpa) about how
+to use the pre-trained model for non-streaming ASR. See
+
+
+#### 2022-03-01
+
+[./transducer_stateless_modified-2](./transducer_stateless_modified-2)
+
+It uses [optimized_transducer](https://github.com/csukuangfj/optimized_transducer)
+for computing RNN-T loss.
+
+Stateless transducer + modified transducer + using [aidatatang_200zh](http://www.openslr.org/62/) as extra training data.
+
+
+| | test |comment |
+|------------------------|------|----------------------------------------------------------------|
+| greedy search | 4.94 |--epoch 89, --avg 38, --max-duration 100, --max-sym-per-frame 1 |
+| modified beam search | 4.68 |--epoch 89, --avg 38, --max-duration 100 --beam-size 4 |
+
+The training commands are:
+
+```bash
+cd egs/aishell/ASR
+./prepare.sh --stop-stage 6
+./prepare_aidatatang_200zh.sh
+
+export CUDA_VISIBLE_DEVICES="0,1,2"
+
+./transducer_stateless_modified-2/train.py \
+ --world-size 3 \
+ --num-epochs 90 \
+ --start-epoch 0 \
+ --exp-dir transducer_stateless_modified-2/exp-2 \
+ --max-duration 250 \
+ --lr-factor 2.0 \
+ --context-size 2 \
+ --modified-transducer-prob 0.25 \
+ --datatang-prob 0.2
+```
+
+The tensorboard log is available at
+
+
+The commands for decoding are
+
+```bash
+# greedy search
+for epoch in 89; do
+ for avg in 38; do
+ ./transducer_stateless_modified-2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir transducer_stateless_modified-2/exp-2 \
+ --max-duration 100 \
+ --context-size 2 \
+ --decoding-method greedy_search \
+ --max-sym-per-frame 1
+ done
+done
+
+# modified beam search
+for epoch in 89; do
+ for avg in 38; do
+ ./transducer_stateless_modified-2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir transducer_stateless_modified-2/exp-2 \
+ --max-duration 100 \
+ --context-size 2 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+ done
+done
+```
+
+You can find a pre-trained model, decoding logs, and decoding results at
+
+
+#### 2022-03-01
+
+[./transducer_stateless_modified](./transducer_stateless_modified)
+
+Stateless transducer + modified transducer.
+
+| | test |comment |
+|------------------------|------|----------------------------------------------------------------|
+| greedy search | 5.22 |--epoch 64, --avg 33, --max-duration 100, --max-sym-per-frame 1 |
+| modified beam search | 5.02 |--epoch 64, --avg 33, --max-duration 100 --beam-size 4 |
+
+The training commands are:
+
+```bash
+cd egs/aishell/ASR
+./prepare.sh --stop-stage 6
+
+export CUDA_VISIBLE_DEVICES="0,1,2"
+
+./transducer_stateless_modified/train.py \
+ --world-size 3 \
+ --num-epochs 90 \
+ --start-epoch 0 \
+ --exp-dir transducer_stateless_modified/exp-4 \
+ --max-duration 250 \
+ --lr-factor 2.0 \
+ --context-size 2 \
+ --modified-transducer-prob 0.25
+```
+
+The tensorboard log is available at
+
+
+The commands for decoding are
+
+```bash
+# greedy search
+for epoch in 64; do
+ for avg in 33; do
+ ./transducer_stateless_modified/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir transducer_stateless_modified/exp-4 \
+ --max-duration 100 \
+ --context-size 2 \
+ --decoding-method greedy_search \
+ --max-sym-per-frame 1
+ done
+done
+
+# modified beam search
+for epoch in 64; do
+ for avg in 33; do
+ ./transducer_stateless_modified/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir transducer_stateless_modified/exp-4 \
+ --max-duration 100 \
+ --context-size 2 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+ done
+done
+```
+
+You can find a pre-trained model, decoding logs, and decoding results at
+
+
+
+#### 2022-02-19
+(Duo Ma): The tensorboard log for training is available at https://tensorboard.dev/experiment/25PmX3MxSVGTdvIdhOwllw/#scalars
+You can find a pretrained model by visiting https://huggingface.co/shuanguanma/icefall_aishell_transducer_stateless_context_size2_epoch60_2022_2_19
+
+|                           | test |comment                                  |
+|---------------------------|------|-----------------------------------------|
+| greedy search             | 5.4  |--epoch 59, --avg 10, --max-duration 100 |
+| beam search               | 5.05 |--epoch 59, --avg 10, --max-duration 100 |
+
+You can use the following commands to reproduce our results:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+python3 ./transducer_stateless/train.py \
+ --world-size 4 \
+ --num-epochs 60 \
+ --start-epoch 0 \
+ --exp-dir exp/transducer_stateless_context_size2 \
+ --max-duration 100 \
+ --lr-factor 2.5 \
+ --context-size 2
+
+lang_dir=data/lang_char
+dir=exp/transducer_stateless_context_size2
+python3 ./transducer_stateless/decode.py \
+ --epoch 59 \
+ --avg 10 \
+ --exp-dir $dir \
+ --lang-dir $lang_dir \
+ --decoding-method greedy_search \
+ --context-size 2 \
+ --max-sym-per-frame 3
+
+lang_dir=data/lang_char
+dir=exp/transducer_stateless_context_size2
+python3 ./transducer_stateless/decode.py \
+ --epoch 59 \
+ --avg 10 \
+ --exp-dir $dir \
+ --lang-dir $lang_dir \
+ --decoding-method beam_search \
+ --context-size 2 \
+ --max-sym-per-frame 3
+```
+
### Aishell training results (Transducer-stateless)
-#### 2021-12-29
-(Pingfeng Luo) : The tensorboard log for training is available at
+#### 2022-02-18
+(Pingfeng Luo) : The tensorboard log for training is available at
+And pretrained model is available at
||test|
|--|--|
-|CER| 5.7% |
+|CER| 5.05% |
You can use the following commands to reproduce our results:
@@ -16,17 +285,17 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8"
--bucketing-sampler True \
--world-size 8 \
--lang-dir data/lang_char \
- --num-epochs 40 \
+ --num-epochs 60 \
--start-epoch 0 \
- --exp-dir transducer_stateless/exp_char \
- --max-duration 160 \
+ --exp-dir transducer_stateless/exp_rnnt_k2 \
+ --max-duration 80 \
--lr-factor 3
./transducer_stateless/decode.py \
- --epoch 39 \
+ --epoch 59 \
--avg 10 \
--lang-dir data/lang_char \
- --exp-dir transducer_stateless/exp_char \
+ --exp-dir transducer_stateless/exp_rnnt_k2 \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py
index 7bd0f95cf..1e3e7b492 100644
--- a/egs/aishell/ASR/conformer_ctc/conformer.py
+++ b/egs/aishell/ASR/conformer_ctc/conformer.py
@@ -364,7 +364,7 @@ class RelPositionalEncoding(torch.nn.Module):
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
- # Suppose `i` means to the position of query vecotr and `j` means the
+ # Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
diff --git a/egs/aishell/ASR/conformer_ctc/label_smoothing.py b/egs/aishell/ASR/conformer_ctc/label_smoothing.py
deleted file mode 100644
--- a/egs/aishell/ASR/conformer_ctc/label_smoothing.py
+++ /dev/null
-import torch
-
-
-class LabelSmoothingLoss(torch.nn.Module):
- def __init__(
- self,
- ignore_index: int = -1,
- label_smoothing: float = 0.1,
- reduction: str = "sum",
- ) -> None:
- """
- Args:
- ignore_index:
- ignored class id
- label_smoothing:
- smoothing rate (0.0 means the conventional cross entropy loss)
- reduction:
- It has the same meaning as the reduction in
- `torch.nn.CrossEntropyLoss`. It can be one of the following three
- values: (1) "none": No reduction will be applied. (2) "mean": the
- mean of the output is taken. (3) "sum": the output will be summed.
- """
- super().__init__()
- assert 0.0 <= label_smoothing < 1.0
- self.ignore_index = ignore_index
- self.label_smoothing = label_smoothing
- self.reduction = reduction
-
- def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
- """
- Compute loss between x and target.
-
- Args:
- x:
- prediction of dimension
- (batch_size, input_length, number_of_classes).
- target:
- target masked with self.ignore_index of
- dimension (batch_size, input_length).
-
- Returns:
- A scalar tensor containing the loss without normalization.
- """
- assert x.ndim == 3
- assert target.ndim == 2
- assert x.shape[:2] == target.shape
- num_classes = x.size(-1)
- x = x.reshape(-1, num_classes)
- # Now x is of shape (N*T, C)
-
- # We don't want to change target in-place below,
- # so we make a copy of it here
- target = target.clone().reshape(-1)
-
- ignored = target == self.ignore_index
- target[ignored] = 0
-
- true_dist = torch.nn.functional.one_hot(
- target, num_classes=num_classes
- ).to(x)
-
- true_dist = (
- true_dist * (1 - self.label_smoothing)
- + self.label_smoothing / num_classes
- )
- # Set the value of ignored indexes to 0
- true_dist[ignored] = 0
-
- loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
- if self.reduction == "sum":
- return loss.sum()
- elif self.reduction == "mean":
- return loss.sum() / (~ignored).sum()
- else:
- return loss.sum(dim=-1)
diff --git a/egs/aishell/ASR/conformer_ctc/label_smoothing.py b/egs/aishell/ASR/conformer_ctc/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/aishell/ASR/conformer_ctc/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py
index a4bc8e3bb..a228cc1fe 100755
--- a/egs/aishell/ASR/conformer_ctc/train.py
+++ b/egs/aishell/ASR/conformer_ctc/train.py
@@ -121,6 +121,13 @@ def get_parser():
""",
)
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
return parser
@@ -188,9 +195,9 @@ def get_params() -> AttributeDict:
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
- "log_interval": 10,
+ "log_interval": 50,
"reset_interval": 200,
- "valid_interval": 3000,
+ "valid_interval": 2000,
# parameters for k2.ctc_loss
"beam_size": 10,
"reduction": "sum",
@@ -555,7 +562,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
- fix_random_seed(42)
+ fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@@ -618,6 +625,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
+ fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
diff --git a/egs/aishell/ASR/conformer_mmi/conformer.py b/egs/aishell/ASR/conformer_mmi/conformer.py
index 7bd0f95cf..1e3e7b492 100644
--- a/egs/aishell/ASR/conformer_mmi/conformer.py
+++ b/egs/aishell/ASR/conformer_mmi/conformer.py
@@ -364,7 +364,7 @@ class RelPositionalEncoding(torch.nn.Module):
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
- # Suppose `i` means to the position of query vecotr and `j` means the
+ # Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
diff --git a/egs/aishell/ASR/conformer_mmi/label_smoothing.py b/egs/aishell/ASR/conformer_mmi/label_smoothing.py
deleted file mode 100644
--- a/egs/aishell/ASR/conformer_mmi/label_smoothing.py
+++ /dev/null
-import torch
-
-
-class LabelSmoothingLoss(torch.nn.Module):
- def __init__(
- self,
- ignore_index: int = -1,
- label_smoothing: float = 0.1,
- reduction: str = "sum",
- ) -> None:
- """
- Args:
- ignore_index:
- ignored class id
- label_smoothing:
- smoothing rate (0.0 means the conventional cross entropy loss)
- reduction:
- It has the same meaning as the reduction in
- `torch.nn.CrossEntropyLoss`. It can be one of the following three
- values: (1) "none": No reduction will be applied. (2) "mean": the
- mean of the output is taken. (3) "sum": the output will be summed.
- """
- super().__init__()
- assert 0.0 <= label_smoothing < 1.0
- self.ignore_index = ignore_index
- self.label_smoothing = label_smoothing
- self.reduction = reduction
-
- def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
- """
- Compute loss between x and target.
-
- Args:
- x:
- prediction of dimension
- (batch_size, input_length, number_of_classes).
- target:
- target masked with self.ignore_index of
- dimension (batch_size, input_length).
-
- Returns:
- A scalar tensor containing the loss without normalization.
- """
- assert x.ndim == 3
- assert target.ndim == 2
- assert x.shape[:2] == target.shape
- num_classes = x.size(-1)
- x = x.reshape(-1, num_classes)
- # Now x is of shape (N*T, C)
-
- # We don't want to change target in-place below,
- # so we make a copy of it here
- target = target.clone().reshape(-1)
-
- ignored = target == self.ignore_index
- target[ignored] = 0
-
- true_dist = torch.nn.functional.one_hot(
- target, num_classes=num_classes
- ).to(x)
-
- true_dist = (
- true_dist * (1 - self.label_smoothing)
- + self.label_smoothing / num_classes
- )
- # Set the value of ignored indexes to 0
- true_dist[ignored] = 0
-
- loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
- if self.reduction == "sum":
- return loss.sum()
- elif self.reduction == "mean":
- return loss.sum() / (~ignored).sum()
- else:
- return loss.sum(dim=-1)
diff --git a/egs/aishell/ASR/conformer_mmi/label_smoothing.py b/egs/aishell/ASR/conformer_mmi/label_smoothing.py
new file mode 120000
index 000000000..08734abd7
--- /dev/null
+++ b/egs/aishell/ASR/conformer_mmi/label_smoothing.py
@@ -0,0 +1 @@
+../conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py
index 79c16d1cc..685831d09 100755
--- a/egs/aishell/ASR/conformer_mmi/train.py
+++ b/egs/aishell/ASR/conformer_mmi/train.py
@@ -124,6 +124,13 @@ def get_parser():
""",
)
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
return parser
@@ -546,7 +553,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
- fix_random_seed(42)
+ fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@@ -613,6 +620,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
+ fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
diff --git a/egs/aishell/ASR/local/compile_hlg.py b/egs/aishell/ASR/local/compile_hlg.py
deleted file mode 100755
index 098d5d6a3..000000000
--- a/egs/aishell/ASR/local/compile_hlg.py
+++ /dev/null
@@ -1,156 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This script takes as input lang_dir and generates HLG from
-
- - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- - L, the lexicon, built from lang_dir/L_disambig.pt
-
- Caution: We use a lexicon that contains disambiguation symbols
-
- - G, the LM, built from data/lm/G_3_gram.fst.txt
-
-The generated HLG is saved in $lang_dir/HLG.pt
-"""
-import argparse
-import logging
-from pathlib import Path
-
-import k2
-import torch
-
-from icefall.lexicon import Lexicon
-
-
-def get_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--lang-dir",
- type=str,
- help="""Input and output directory.
- """,
- )
-
- return parser.parse_args()
-
-
-def compile_HLG(lang_dir: str) -> k2.Fsa:
- """
- Args:
- lang_dir:
- The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
-
- Return:
- An FSA representing HLG.
- """
- lexicon = Lexicon(lang_dir)
- max_token_id = max(lexicon.tokens)
- logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
- H = k2.ctc_topo(max_token_id)
- L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
-
- if Path("data/lm/G_3_gram.pt").is_file():
- logging.info("Loading pre-compiled G_3_gram")
- d = torch.load("data/lm/G_3_gram.pt")
- G = k2.Fsa.from_dict(d)
- else:
- logging.info("Loading G_3_gram.fst.txt")
- with open("data/lm/G_3_gram.fst.txt") as f:
- G = k2.Fsa.from_openfst(f.read(), acceptor=False)
- torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
-
- first_token_disambig_id = lexicon.token_table["#0"]
- first_word_disambig_id = lexicon.word_table["#0"]
-
- L = k2.arc_sort(L)
- G = k2.arc_sort(G)
-
- logging.info("Intersecting L and G")
- LG = k2.compose(L, G)
- logging.info(f"LG shape: {LG.shape}")
-
- logging.info("Connecting LG")
- LG = k2.connect(LG)
- logging.info(f"LG shape after k2.connect: {LG.shape}")
-
- logging.info(type(LG.aux_labels))
- logging.info("Determinizing LG")
-
- LG = k2.determinize(LG)
- logging.info(type(LG.aux_labels))
-
- logging.info("Connecting LG after k2.determinize")
- LG = k2.connect(LG)
-
- logging.info("Removing disambiguation symbols on LG")
-
- LG.labels[LG.labels >= first_token_disambig_id] = 0
-
- assert isinstance(LG.aux_labels, k2.RaggedTensor)
- LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
-
- LG = k2.remove_epsilon(LG)
- logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
-
- LG = k2.connect(LG)
- LG.aux_labels = LG.aux_labels.remove_values_eq(0)
-
- logging.info("Arc sorting LG")
- LG = k2.arc_sort(LG)
-
- logging.info("Composing H and LG")
- # CAUTION: The name of the inner_labels is fixed
- # to `tokens`. If you want to change it, please
- # also change other places in icefall that are using
- # it.
- HLG = k2.compose(H, LG, inner_labels="tokens")
-
- logging.info("Connecting LG")
- HLG = k2.connect(HLG)
-
- logging.info("Arc sorting LG")
- HLG = k2.arc_sort(HLG)
- logging.info(f"HLG.shape: {HLG.shape}")
-
- return HLG
-
-
-def main():
- args = get_args()
- lang_dir = Path(args.lang_dir)
-
- if (lang_dir / "HLG.pt").is_file():
- logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
- return
-
- logging.info(f"Processing {lang_dir}")
-
- HLG = compile_HLG(lang_dir)
- logging.info(f"Saving HLG.pt to {lang_dir}")
- torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
-
-
-if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
-
- logging.basicConfig(format=formatter, level=logging.INFO)
-
- main()
diff --git a/egs/aishell/ASR/local/compile_hlg.py b/egs/aishell/ASR/local/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/aishell/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
new file mode 100755
index 000000000..8cdfad71f
--- /dev/null
+++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the aidatatang_200zh dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+ num_jobs = min(15, os.cpu_count())
+
+ dataset_parts = (
+ "train",
+ "test",
+ "dev",
+ )
+ prefix = "aidatatang"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+
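+ # Tag each supervision with its source corpus so that the origin of
+ # every cut can still be identified later, e.g., when these cuts are
+ # mixed with data from other corpora.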
+ for sup in m["supervisions"]:
+ sup.custom = {"origin": "aidatatang_200zh"}
+
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ if "train" in partition:
+ cut_set = (
+ cut_set
+ + cut_set.perturb_speed(0.9)
+ + cut_set.perturb_speed(1.1)
+ )
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+
+ cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--num-mel-bins",
+ type=int,
+ default=80,
+ help="""The number of mel bins for Fbank""",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index b3b9e7681..e27e35ec5 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -29,7 +29,7 @@ import os
from pathlib import Path
import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@@ -52,8 +52,13 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
"dev",
"test",
)
+ prefix = "aishell"
+ suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
- dataset_parts=dataset_parts, output_dir=src_dir
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
)
assert manifests is not None
@@ -61,7 +66,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
- if (output_dir / f"cuts_{partition}.json.gz").is_file():
+ if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
@@ -77,13 +82,13 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
- storage_path=f"{output_dir}/feats_{partition}",
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
- storage_type=LilcomHdf5Writer,
+ storage_type=LilcomChunkyWriter,
)
- cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
+ cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
def get_args():
diff --git a/egs/aishell/ASR/local/compute_fbank_musan.py b/egs/aishell/ASR/local/compute_fbank_musan.py
deleted file mode 100755
index e79bdafb1..000000000
--- a/egs/aishell/ASR/local/compute_fbank_musan.py
+++ /dev/null
@@ -1,110 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This file computes fbank features of the musan dataset.
-It looks for manifests in the directory data/manifests.
-
-The generated fbank features are saved in data/fbank.
-"""
-
-import argparse
-import logging
-import os
-from pathlib import Path
-
-import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine
-from lhotse.recipes.utils import read_manifests_if_cached
-
-from icefall.utils import get_executor
-
-# Torch's multithreaded behavior needs to be disabled or
-# it wastes a lot of CPU and slow things down.
-# Do this outside of main() in case it needs to take effect
-# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-
-def compute_fbank_musan(num_mel_bins: int = 80):
- src_dir = Path("data/manifests")
- output_dir = Path("data/fbank")
- num_jobs = min(15, os.cpu_count())
-
- dataset_parts = (
- "music",
- "speech",
- "noise",
- )
- manifests = read_manifests_if_cached(
- dataset_parts=dataset_parts, output_dir=src_dir
- )
- assert manifests is not None
-
- musan_cuts_path = output_dir / "cuts_musan.json.gz"
-
- if musan_cuts_path.is_file():
- logging.info(f"{musan_cuts_path} already exists - skipping")
- return
-
- logging.info("Extracting features for Musan")
-
- extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
-
- with get_executor() as ex: # Initialize the executor only once.
- # create chunks of Musan with duration 5 - 10 seconds
- musan_cuts = (
- CutSet.from_manifests(
- recordings=combine(
- part["recordings"] for part in manifests.values()
- )
- )
- .cut_into_windows(10.0)
- .filter(lambda c: c.duration > 5)
- .compute_and_store_features(
- extractor=extractor,
- storage_path=f"{output_dir}/feats_musan",
- num_jobs=num_jobs if ex is None else 80,
- executor=ex,
- storage_type=LilcomHdf5Writer,
- )
- )
- musan_cuts.to_json(musan_cuts_path)
-
-
-def get_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--num-mel-bins",
- type=int,
- default=80,
- help="""The number of mel bins for Fbank""",
- )
-
- return parser.parse_args()
-
-
-if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
-
- logging.basicConfig(format=formatter, level=logging.INFO)
- args = get_args()
- compute_fbank_musan(num_mel_bins=args.num_mel_bins)
diff --git a/egs/aishell/ASR/local/compute_fbank_musan.py b/egs/aishell/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/aishell/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/local/convert_transcript_words_to_tokens.py b/egs/aishell/ASR/local/convert_transcript_words_to_tokens.py
deleted file mode 100755
index 133499c8b..000000000
--- a/egs/aishell/ASR/local/convert_transcript_words_to_tokens.py
+++ /dev/null
@@ -1,107 +0,0 @@
-#!/usr/bin/env python3
-
-# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
-"""
-Convert a transcript file containing words to a corpus file containing tokens
-for LM training with the help of a lexicon.
-
-If the lexicon contains phones, the resulting LM will be a phone LM; If the
-lexicon contains word pieces, the resulting LM will be a word piece LM.
-
-If a word has multiple pronunciations, the one that appears first in the lexicon
-is kept; others are removed.
-
-If the input transcript is:
-
- hello zoo world hello
- world zoo
- foo zoo world hellO
-
-and if the lexicon is
-
- <UNK> SPN
- hello h e l l o 2
- hello h e l l o
- world w o r l d
- zoo z o o
-
-Then the output is
-
- h e l l o 2 z o o w o r l d h e l l o 2
- w o r l d z o o
- SPN z o o w o r l d SPN
-"""
-
-import argparse
-from pathlib import Path
-from typing import Dict, List
-
-from generate_unique_lexicon import filter_multiple_pronunications
-
-from icefall.lexicon import read_lexicon
-
-
-def get_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--transcript",
- type=str,
- help="The input transcript file."
- "We assume that the transcript file consists of "
- "lines. Each line consists of space separated words.",
- )
- parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
- parser.add_argument(
- "--oov", type=str, default="", help="The OOV word."
- )
-
- return parser.parse_args()
-
-
-def process_line(
- lexicon: Dict[str, List[str]], line: str, oov_token: str
-) -> None:
- """
- Args:
- lexicon:
- A dict containing pronunciations. Its keys are words and values
- are pronunciations (i.e., tokens).
- line:
- A line of transcript consisting of space(s) separated words.
- oov_token:
- The pronunciation of the oov word if a word in `line` is not present
- in the lexicon.
- Returns:
- Return None.
- """
- s = ""
- words = line.strip().split()
- for i, w in enumerate(words):
- tokens = lexicon.get(w, oov_token)
- s += " ".join(tokens)
- s += " "
- print(s.strip())
-
-
-def main():
- args = get_args()
- assert Path(args.lexicon).is_file()
- assert Path(args.transcript).is_file()
- assert len(args.oov) > 0
-
- # Only the first pronunciation of a word is kept
- lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
-
- lexicon = dict(lexicon)
-
- assert args.oov in lexicon
-
- oov_token = lexicon[args.oov]
-
- with open(args.transcript) as f:
- for line in f:
- process_line(lexicon=lexicon, line=line, oov_token=oov_token)
-
-
-if __name__ == "__main__":
- main()
diff --git a/egs/aishell/ASR/local/convert_transcript_words_to_tokens.py b/egs/aishell/ASR/local/convert_transcript_words_to_tokens.py
new file mode 120000
index 000000000..2ce13fd69
--- /dev/null
+++ b/egs/aishell/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/local/display_manifest_statistics.py b/egs/aishell/ASR/local/display_manifest_statistics.py
new file mode 100755
index 000000000..c478f7331
--- /dev/null
+++ b/egs/aishell/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed value to choose minimum/maximum duration
+to remove short and long utterances during the training.
+
+See the function `remove_short_and_long_utt()` in transducer_stateless/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest_lazy
+
+
+def main():
+ # path = "./data/fbank/aishell_cuts_train.jsonl.gz"
+ # path = "./data/fbank/aishell_cuts_test.jsonl.gz"
+ path = "./data/fbank/aishell_cuts_dev.jsonl.gz"
+ # path = "./data/fbank/aidatatang_cuts_train.jsonl.gz"
+ # path = "./data/fbank/aidatatang_cuts_test.jsonl.gz"
+ # path = "./data/fbank/aidatatang_cuts_dev.jsonl.gz"
+
+ cuts = load_manifest_lazy(path)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+## train (after speed perturb)
+Cuts count: 360294
+Total duration (hours): 455.6
+Speech duration (hours): 455.6 (100.0%)
+***
+Duration statistics (seconds):
+mean 4.6
+std 1.4
+min 1.1
+0.1% 1.8
+0.5% 2.2
+1% 2.3
+5% 2.7
+10% 3.0
+25% 3.5
+50% 4.3
+75% 5.4
+90% 6.5
+95% 7.2
+99% 8.8
+99.5% 9.4
+99.9% 10.9
+max 16.1
+
+## test
+Cuts count: 7176
+Total duration (hours): 10.0
+Speech duration (hours): 10.0 (100.0%)
+***
+Duration statistics (seconds):
+mean 5.0
+std 1.6
+min 1.9
+0.1% 2.2
+0.5% 2.4
+1% 2.6
+5% 3.0
+10% 3.2
+25% 3.8
+50% 4.7
+75% 5.9
+90% 7.3
+95% 8.2
+99% 9.9
+99.5% 10.7
+99.9% 11.9
+max 14.7
+
+## dev
+Cuts count: 14326
+Total duration (hours): 18.1
+Speech duration (hours): 18.1 (100.0%)
+***
+Duration statistics (seconds):
+mean 4.5
+std 1.3
+min 1.6
+0.1% 2.1
+0.5% 2.3
+1% 2.4
+5% 2.9
+10% 3.1
+25% 3.5
+50% 4.3
+75% 5.4
+90% 6.4
+95% 7.0
+99% 8.4
+99.5% 8.9
+99.9% 10.3
+max 12.5
+
+## aidatatang_200zh (train)
+Cuts count: 164905
+Total duration (hours): 139.9
+Speech duration (hours): 139.9 (100.0%)
+***
+Duration statistics (seconds):
+mean 3.1
+std 1.1
+min 1.1
+0.1% 1.5
+0.5% 1.7
+1% 1.8
+5% 2.0
+10% 2.1
+25% 2.3
+50% 2.7
+75% 3.4
+90% 4.6
+95% 5.4
+99% 7.1
+99.5% 7.8
+99.9% 9.1
+max 16.3
+
+## aidatatang_200zh (test)
+Cuts count: 48144
+Total duration (hours): 40.2
+Speech duration (hours): 40.2 (100.0%)
+***
+Duration statistics (seconds):
+mean 3.0
+std 1.1
+min 0.9
+0.1% 1.5
+0.5% 1.8
+1% 1.8
+5% 2.0
+10% 2.1
+25% 2.3
+50% 2.6
+75% 3.4
+90% 4.4
+95% 5.2
+99% 6.9
+99.5% 7.5
+99.9% 9.0
+max 21.8
+
+## aidatatang_200zh (dev)
+Cuts count: 24216
+Total duration (hours): 20.2
+Speech duration (hours): 20.2 (100.0%)
+***
+Duration statistics (seconds):
+mean 3.0
+std 1.0
+min 1.2
+0.1% 1.6
+0.5% 1.7
+1% 1.8
+5% 2.0
+10% 2.1
+25% 2.3
+50% 2.7
+75% 3.4
+90% 4.4
+95% 5.1
+99% 6.7
+99.5% 7.3
+99.9% 8.8
+max 11.3
+"""
diff --git a/egs/aishell/ASR/local/generate_unique_lexicon.py b/egs/aishell/ASR/local/generate_unique_lexicon.py
deleted file mode 100755
index 566c0743d..000000000
--- a/egs/aishell/ASR/local/generate_unique_lexicon.py
+++ /dev/null
@@ -1,100 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This file takes as input a lexicon.txt and output a new lexicon,
-in which each word has a unique pronunciation.
-
-The way to do this is to keep only the first pronunciation of a word
-in lexicon.txt.
-"""
-
-
-import argparse
-import logging
-from pathlib import Path
-from typing import List, Tuple
-
-from icefall.lexicon import read_lexicon, write_lexicon
-
-
-def get_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--lang-dir",
- type=str,
- help="""Input and output directory.
- It should contain a file lexicon.txt.
- This file will generate a new file uniq_lexicon.txt
- in it.
- """,
- )
-
- return parser.parse_args()
-
-
-def filter_multiple_pronunications(
- lexicon: List[Tuple[str, List[str]]]
-) -> List[Tuple[str, List[str]]]:
- """Remove multiple pronunciations of words from a lexicon.
-
- If a word has more than one pronunciation in the lexicon, only
- the first one is kept, while other pronunciations are removed
- from the lexicon.
-
- Args:
- lexicon:
- The input lexicon, containing a list of (word, [p1, p2, ..., pn]),
- where "p1, p2, ..., pn" are the pronunciations of the "word".
- Returns:
- Return a new lexicon where each word has a unique pronunciation.
- """
- seen = set()
- ans = []
-
- for word, tokens in lexicon:
- if word in seen:
- continue
- seen.add(word)
- ans.append((word, tokens))
- return ans
-
-
-def main():
- args = get_args()
- lang_dir = Path(args.lang_dir)
-
- lexicon_filename = lang_dir / "lexicon.txt"
-
- in_lexicon = read_lexicon(lexicon_filename)
-
- out_lexicon = filter_multiple_pronunications(in_lexicon)
-
- write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon)
-
- logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}")
- logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}")
-
-
-if __name__ == "__main__":
- formatter = (
- "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- )
-
- logging.basicConfig(format=formatter, level=logging.INFO)
-
- main()
diff --git a/egs/aishell/ASR/local/generate_unique_lexicon.py b/egs/aishell/ASR/local/generate_unique_lexicon.py
new file mode 120000
index 000000000..c0aea1403
--- /dev/null
+++ b/egs/aishell/ASR/local/generate_unique_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index a99558395..f86dd8de3 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -18,7 +18,7 @@ stop_stage=10
# This directory contains the language model downloaded from
# https://huggingface.co/pkufool/aishell_lm
#
-# - 3-gram.unpruned.apra
+# - 3-gram.unpruned.arpa
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
@@ -48,8 +48,11 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: Download LM"
# We assume that you have installed the git-lfs, if not, you could install it
# using: `sudo apt-get install git-lfs && git-lfs install`
- [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
- git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
+ git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
+
+ if [ ! -f $dl_dir/lm/3-gram.unpruned.arpa ]; then
+ git clone https://huggingface.co/pkufool/aishell_lm $dl_dir/lm
+ fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
@@ -69,7 +72,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# |-- lexicon.txt
# `-- speaker.info
- if [ ! -d $dl_dir/aishell/data_aishell/wav ]; then
+ if [ ! -d $dl_dir/aishell/data_aishell/wav/train ]; then
lhotse download aishell $dl_dir
fi
@@ -87,28 +90,41 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare aishell manifest"
# We assume that you have downloaded the aishell corpus
# to $dl_dir/aishell
- mkdir -p data/manifests
- lhotse prepare aishell -j $nj $dl_dir/aishell data/manifests
+ if [ ! -f data/manifests/.aishell_manifests.done ]; then
+ mkdir -p data/manifests
+ lhotse prepare aishell $dl_dir/aishell data/manifests
+ touch data/manifests/.aishell_manifests.done
+ fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to data/musan
- mkdir -p data/manifests
- lhotse prepare musan $dl_dir/musan data/manifests
+ if [ ! -f data/manifests/.musan_manifests.done ]; then
+ log "It may take 6 minutes"
+ mkdir -p data/manifests
+ lhotse prepare musan $dl_dir/musan data/manifests
+ touch data/manifests/.musan_manifests.done
+ fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for aishell"
- mkdir -p data/fbank
- ./local/compute_fbank_aishell.py
+ if [ ! -f data/fbank/.aishell.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_aishell.py
+ touch data/fbank/.aishell.done
+ fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
- mkdir -p data/fbank
- ./local/compute_fbank_musan.py
+ if [ ! -f data/fbank/.musan.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_musan.py
+ touch data/fbank/.musan.done
+ fi
fi
lang_phone_dir=data/lang_phone
@@ -134,7 +150,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid
find $dl_dir/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text |
- cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
+ cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
fi
if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
diff --git a/egs/aishell/ASR/prepare_aidatatang_200zh.sh b/egs/aishell/ASR/prepare_aidatatang_200zh.sh
new file mode 100755
index 000000000..f1d4d18a7
--- /dev/null
+++ b/egs/aishell/ASR/prepare_aidatatang_200zh.sh
@@ -0,0 +1,59 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/aidatatang_200zh
+# You can find "corpus" and "transcript" inside it.
+# You can download it at
+# https://openslr.org/62/
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ if [ ! -f $dl_dir/aidatatang_200zh/transcript/aidatatang_200_zh_transcript.txt ]; then
+ lhotse download aidatatang-200zh $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare manifest"
+ # We assume that you have downloaded the aidatatang_200zh corpus
+ # to $dl_dir/aidatatang_200zh
+ if [ ! -f data/manifests/.aidatatang_200zh_manifests.done ]; then
+ mkdir -p data/manifests
+ lhotse prepare aidatatang-200zh $dl_dir data/manifests
+ touch data/manifests/.aidatatang_200zh_manifests.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Process aidatatang_200zh"
+ if [ ! -f data/fbank/.aidatatang_200zh_fbank.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_aidatatang_200zh.py
+ touch data/fbank/.aidatatang_200zh_fbank.done
+ fi
+fi
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/aidatatang_200zh.py b/egs/aishell/ASR/pruned_transducer_stateless3/aidatatang_200zh.py
new file mode 120000
index 000000000..9a799406b
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/aidatatang_200zh.py
@@ -0,0 +1 @@
+../transducer_stateless_modified-2/aidatatang_200zh.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/aishell.py b/egs/aishell/ASR/pruned_transducer_stateless3/aishell.py
new file mode 120000
index 000000000..1b5f38a54
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/aishell.py
@@ -0,0 +1 @@
+../transducer_stateless_modified-2/aishell.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless3/asr_datamodule.py
new file mode 120000
index 000000000..ae3bdd1e0
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/asr_datamodule.py
@@ -0,0 +1 @@
+../transducer_stateless_modified-2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless3/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/conformer.py b/egs/aishell/ASR/pruned_transducer_stateless3/conformer.py
new file mode 120000
index 000000000..c7c1a4b6e
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
new file mode 100755
index 000000000..6aea306c8
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
@@ -0,0 +1,638 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless3/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless3/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless3/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless3/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless3/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless3/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search
+./pruned_transducer_stateless3/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless3/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from aishell import AIShell
+from asr_datamodule import AsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=False,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless3/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=1,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+        Used only when --decoding-method is greedy_search""",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ token_table: k2.SymbolTable,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+        if greedy_search is used, it would be "greedy_search".
+        If beam search with a beam size of 7 is used, it would be
+        "beam_7".
+      - value: It contains the decoding result. `len(value)` equals the
+        batch size. `value[i]` is the decoding result for the i-th
+        utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ token_table:
+ It maps token ID to a string.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
+        Used only when --decoding-method is fast_beam_search.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=feature, x_lens=feature_lens
+ )
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ else:
+ hyp_tokens = []
+ batch_size = encoder_out.size(0)
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyp_tokens.append(hyp)
+
+ hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ token_table: k2.SymbolTable,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ token_table:
+ It maps a token ID to a string.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
+        Used only when --decoding-method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+        is used, or it may be "beam_7" if a beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ token_table=token_table,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+    # We compute CER for the aishell dataset.
+ results_char = []
+ for res in results:
+ results_char.append((list("".join(res[0])), list("".join(res[1]))))
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results_char, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir
+ / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tCER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+
+ params = get_params()
+ params.update(vars(args))
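+    # With datatang_prob set to 0, get_transducer_model() from train.py will
+    # not create the extra decoder/joiner used for aidatatang_200zh, which is
+    # presumably why load_state_dict() below is called with strict=False.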
+ params.datatang_prob = 0
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += (
+ f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ )
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg + 1]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ ),
+ strict=False,
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ ),
+ strict=False,
+ )
+
+ model.to(device)
+ model.eval()
+
+ if params.decoding_method == "fast_beam_search":
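+        # k2.trivial_graph() takes the largest token id (vocab_size - 1 here,
+        # since vocab_size = max(lexicon.tokens) + 1) and builds a graph that
+        # accepts any token sequence.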
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ asr_datamodule = AsrDataModule(args)
+ aishell = AIShell(manifest_dir=args.manifest_dir)
+ test_cuts = aishell.test_cuts()
+ dev_cuts = aishell.valid_cuts()
+ test_dl = asr_datamodule.test_dataloaders(test_cuts)
+ dev_dl = asr_datamodule.test_dataloaders(dev_cuts)
+
+ test_sets = ["test", "dev"]
+ test_dls = [test_dl, dev_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ token_table=lexicon.token_table,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless3/decoder.py
new file mode 120000
index 000000000..722e1c894
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless3/encoder_interface.py
new file mode 120000
index 000000000..f58253127
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 b/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1
new file mode 120000
index 000000000..bcd4abc2f
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1
@@ -0,0 +1 @@
+/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
new file mode 100755
index 000000000..566902a85
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
@@ -0,0 +1,278 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless3/export.py \
+ --exp-dir ./pruned_transducer_stateless3/exp \
+ --jit 0 \
+ --epoch 29 \
+ --avg 5
+
+It will generate a file exp_dir/pretrained-epoch-29-avg-5.pt
+
+To use the generated file with `pruned_transducer_stateless3/decode.py`,
+you can do::
+
+ cd /path/to/exp_dir
+ ln -s pretrained-epoch-29-avg-5.pt epoch-9999.pt
+
+ cd /path/to/egs/aishell/ASR
+ ./pruned_transducer_stateless3/decode.py \
+ --exp-dir ./pruned_transducer_stateless3/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 100 \
+ --lang-dir data/lang_char
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=29,
+ help="""It specifies the checkpoint to use for averaging.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=Path,
+ default=Path("pruned_transducer_stateless3/exp"),
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default=Path("data/lang_char"),
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=1,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def main():
+ args = get_parser().parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+ params.datatang_prob = 0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg + 1]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit:
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = (
+ params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
+ )
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torch.jit.script")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = (
+ params.exp_dir
+ / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
+ )
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless3/joiner.py
new file mode 120000
index 000000000..9052f3cbb
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
new file mode 100644
index 000000000..e150e8230
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py
@@ -0,0 +1,236 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+from scaling import ScaledLinear
+
+from icefall.utils import add_sos
+
+
+class Transducer(nn.Module):
+ """It implements https://arxiv.org/pdf/1211.3711.pdf
+ "Sequence Transduction with Recurrent Neural Networks"
+ """
+
+ def __init__(
+ self,
+ encoder: EncoderInterface,
+ decoder: nn.Module,
+ joiner: nn.Module,
+ encoder_dim: int,
+ decoder_dim: int,
+ joiner_dim: int,
+ vocab_size: int,
+ decoder_datatang: Optional[nn.Module] = None,
+ joiner_datatang: Optional[nn.Module] = None,
+ ):
+ """
+ Args:
+ encoder:
+            It is the transcription network in the paper. It accepts
+            two inputs: `x` of shape (N, T, C) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim)
+            and `logit_lens` of shape (N,).
+ decoder:
+ It is the prediction network in the paper. Its input shape
+ is (N, U) and its output shape is (N, U, decoder_dim).
+ It should contain one attribute: `blank_id`.
+ joiner:
+ It has two inputs with shapes: (N, T, encoder_dim) and
+ (N, U, decoder_dim). Its output shape is (N, T, U, vocab_size).
+ Note that its output contains
+ unnormalized probs, i.e., not processed by log-softmax.
+ encoder_dim:
+ Output dimension of the encoder network.
+ decoder_dim:
+ Output dimension of the decoder network.
+ joiner_dim:
+ Input dimension of the joiner network.
+ vocab_size:
+ Output dimension of the joiner network.
+ decoder_datatang:
+ Optional. The decoder network for the aidatatang_200zh dataset.
+ joiner_datatang:
+ Optional. The joiner network for the aidatatang_200zh dataset.
+ """
+ super().__init__()
+
+ assert isinstance(encoder, EncoderInterface), type(encoder)
+ assert hasattr(decoder, "blank_id")
+
+ self.encoder = encoder
+ self.decoder = decoder
+ self.joiner = joiner
+
+ self.decoder_datatang = decoder_datatang
+ self.joiner_datatang = joiner_datatang
+
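+        # The "simple" projections map encoder/decoder outputs directly to
+        # vocab_size; they feed k2.rnnt_loss_smoothed (where the joiner is
+        # just an addition) to obtain the pruning ranges in forward().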
+ self.simple_am_proj = ScaledLinear(
+ encoder_dim, vocab_size, initial_speed=0.5
+ )
+ self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
+
+ if decoder_datatang is not None:
+ self.simple_am_proj_datatang = ScaledLinear(
+ encoder_dim, vocab_size, initial_speed=0.5
+ )
+ self.simple_lm_proj_datatang = ScaledLinear(decoder_dim, vocab_size)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: k2.RaggedTensor,
+ aishell: bool = True,
+ prune_range: int = 5,
+ am_scale: float = 0.0,
+ lm_scale: float = 0.0,
+ warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C).
+ x_lens:
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
+ before padding.
+ y:
+ A ragged tensor with 2 axes [utt][label]. It contains labels of each
+ utterance.
+ aishell:
+ True to use the decoder and joiner for the aishell dataset.
+ False to use the decoder and joiner for the aidatatang_200zh
+ dataset.
+ prune_range:
+            The prune range for rnnt loss; it means how many symbols (context)
+            we consider for each frame when computing the loss.
+ am_scale:
+ The scale to smooth the loss with am (output of encoder network)
+ part
+ lm_scale:
+ The scale to smooth the loss with lm (output of predictor network)
+ part
+ warmup:
+            A value warmup >= 0 that determines which modules are active.
+            Values warmup > 1 are fully warmed up and all modules are active.
+ Returns:
+          Return a tuple (simple_loss, pruned_loss) of two scalar loss tensors.
+
+ Note:
+ Regarding am_scale & lm_scale, it will make the loss-function one of
+ the form:
+ lm_scale * lm_probs + am_scale * am_probs +
+ (1-lm_scale-am_scale) * combined_probs
+ """
+ assert x.ndim == 3, x.shape
+ assert x_lens.ndim == 1, x_lens.shape
+ assert y.num_axes == 2, y.num_axes
+
+ assert x.size(0) == x_lens.size(0) == y.dim0
+
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
+ assert torch.all(encoder_out_lens > 0)
+
+ if aishell:
+ decoder = self.decoder
+ simple_lm_proj = self.simple_lm_proj
+ simple_am_proj = self.simple_am_proj
+ joiner = self.joiner
+ else:
+ decoder = self.decoder_datatang
+ simple_lm_proj = self.simple_lm_proj_datatang
+ simple_am_proj = self.simple_am_proj_datatang
+ joiner = self.joiner_datatang
+
+ # Now for the decoder, i.e., the prediction network
+ row_splits = y.shape.row_splits(1)
+ y_lens = row_splits[1:] - row_splits[:-1]
+
+ blank_id = decoder.blank_id
+ sos_y = add_sos(y, sos_id=blank_id)
+
+ # sos_y_padded: [B, S + 1], start with SOS.
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+ # decoder_out: [B, S + 1, decoder_dim]
+ decoder_out = decoder(sos_y_padded)
+
+ # Note: y does not start with SOS
+ # y_padded : [B, S]
+ y_padded = y.pad(mode="constant", padding_value=0)
+
+ y_padded = y_padded.to(torch.int64)
+ boundary = torch.zeros(
+ (x.size(0), 4), dtype=torch.int64, device=x.device
+ )
+ boundary[:, 2] = y_lens
+ boundary[:, 3] = encoder_out_lens
+
+ lm = simple_lm_proj(decoder_out)
+ am = simple_am_proj(encoder_out)
+
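+        # The smoothed simple loss mixes lm-only, am-only and combined probs;
+        # e.g., lm_only_scale=0.25 and am_only_scale=0.0 gives
+        # 0.25 * lm_probs + 0.75 * combined_probs (see the docstring above).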
+ with torch.cuda.amp.autocast(enabled=False):
+ simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+ lm=lm.float(),
+ am=am.float(),
+ symbols=y_padded,
+ termination_symbol=blank_id,
+ lm_only_scale=lm_scale,
+ am_only_scale=am_scale,
+ boundary=boundary,
+ reduction="sum",
+ return_grad=True,
+ )
+
+ # ranges : [B, T, prune_range]
+ ranges = k2.get_rnnt_prune_ranges(
+ px_grad=px_grad,
+ py_grad=py_grad,
+ boundary=boundary,
+ s_range=prune_range,
+ )
+
+ # am_pruned : [B, T, prune_range, encoder_dim]
+ # lm_pruned : [B, T, prune_range, decoder_dim]
+ am_pruned, lm_pruned = k2.do_rnnt_pruning(
+ am=joiner.encoder_proj(encoder_out),
+ lm=joiner.decoder_proj(decoder_out),
+ ranges=ranges,
+ )
+
+ # logits : [B, T, prune_range, vocab_size]
+
+        # project_input=False since we have already applied the joiner's
+        # encoder_proj and decoder_proj prior to do_rnnt_pruning
+        # (this is an optimization for speed).
+ logits = joiner(am_pruned, lm_pruned, project_input=False)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ pruned_loss = k2.rnnt_loss_pruned(
+ logits=logits.float(),
+ symbols=y_padded,
+ ranges=ranges,
+ termination_symbol=blank_id,
+ boundary=boundary,
+ reduction="sum",
+ )
+
+ return (simple_loss, pruned_loss)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/optim.py b/egs/aishell/ASR/pruned_transducer_stateless3/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
new file mode 100755
index 000000000..04a0a882a
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
@@ -0,0 +1,338 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless3/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless3/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) modified beam search
+./pruned_transducer_stateless3/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless3/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method fast_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.lexicon import Lexicon
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default=Path("data/lang_char"),
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=1,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="Maximum number of symbols per frame. "
+ "Use only when --method is greedy_search",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert sample_rate == expected_sample_rate, (
+ f"expected sample rate: {expected_sample_rate}. "
+ f"Given: {sample_rate}"
+ )
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+ params.datatang_prob = 0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lens = [f.size(0) for f in features]
+ feature_lens = torch.tensor(feature_lens, device=device)
+
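+    # Pad with log(1e-10), a very small log-energy that roughly corresponds
+    # to silence in the log-mel filterbank domain.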
+ features = pad_sequence(
+ features, batch_first=True, padding_value=math.log(1e-10)
+ )
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=features, x_lens=feature_lens
+ )
+
+ num_waves = encoder_out.size(0)
+ hyp_list = []
+ logging.info(f"Using {params.method}")
+
+ if params.method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ hyp_list = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_list = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ elif params.method == "modified_beam_search":
+ hyp_list = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ else:
+ for i in range(num_waves):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.method}"
+ )
+ hyp_list.append(hyp)
+
+ hyps = []
+ for hyp in hyp_list:
+ hyps.append([lexicon.token_table[i] for i in hyp])
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless3/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
new file mode 100755
index 000000000..0e5291b21
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -0,0 +1,1264 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+#                                                    Mingshuang Luo,
+#                                                    Zengwei Yao)
+# Copyright 2021 (Pingfeng Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+./prepare.sh
+./prepare_aidatatang_200zh.sh
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+
+./pruned_transducer_stateless3/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 0 \
+ --exp-dir pruned_transducer_stateless3/exp \
+ --max-duration 300 \
+ --datatang-prob 0.2
+
+# For mixed precision training:
+
+./pruned_transducer_stateless3/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless3/exp \
+ --max-duration 550
+"""
+
+
+import argparse
+import copy
+import logging
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from aidatatang_200zh import AIDatatang200zh
+from aishell import AIShell
+from asr_datamodule import AsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse import CutSet, load_manifest
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[
+ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=int,
+ default=12,
+ help="Number of conformer encoder layers..",
+ )
+
+ parser.add_argument(
+ "--dim-feedforward",
+ type=int,
+ default=2048,
+ help="Feedforward dimension of the conformer encoder layer.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=int,
+ default=8,
+ help="Number of attention heads in the conformer encoder layer.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=int,
+ default=512,
+ help="Attention dimension in the conformer encoder layer.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless3/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--initial-lr",
+ type=float,
+ default=0.003,
+ help="The initial learning rate. This value should not need "
+ "to be changed.",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=1,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)"
+ "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=100,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--datatang-prob",
+ type=float,
+ default=0.2,
+ help="""The probability to select a batch from the
+ aidatatang_200zh dataset.
+ If it is set to 0, you don't need to download the data
+ for aidatatang_200zh.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+      - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+      - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+      - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warm_step for Noam optimizer.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 1000,
+ # parameters for conformer
+ "feature_dim": 80,
+ "subsampling_factor": 4,
+ # parameters for Noam
+ "model_warm_step": 3000, # arg given to model, not for lrate
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Conformer and Transformer
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.encoder_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ if params.datatang_prob > 0:
+ decoder_datatang = get_decoder_model(params)
+ joiner_datatang = get_joiner_model(params)
+ else:
+ decoder_datatang = None
+ joiner_datatang = None
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ decoder_datatang=decoder_datatang,
+ joiner_datatang=joiner_datatang,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
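+
+ The checkpoint is written to `{params.exp_dir}/epoch-{params.cur_epoch}.pt`.
+ It is also copied to `best-train-loss.pt` / `best-valid-loss.pt` when the
+ current epoch has the best training / validation loss so far. Only the
+ node with rank 0 saves checkpoints.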
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def is_aishell(c: Cut) -> bool:
+ """Return True if this cut is from the AIShell dataset.
+
+ Note:
+ During data preparation, we set the custom field in
+ the supervision segment of aidatatang_200zh to
+ dict(origin='aidatatang_200zh').
+ See ../local/process_aidatatang_200zh.py.
+ """
+ return c.supervisions[0].custom is None
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+ warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute RNN-T loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Transducer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = (
+ model.device
+ if isinstance(model, DDP)
+ else next(model.parameters()).device
+ )
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ aishell = is_aishell(supervisions["cut"][0])
+
+ texts = batch["supervisions"]["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ aishell=aishell,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ warmup=warmup,
+ )
+ # after the main warmup step, we keep pruned_loss_scale small
+ # for the same amount of time (model_warm_step), to avoid
+ # overwhelming the simple_loss and causing it to diverge,
+ # in case it had not fully learned the alignment yet.
+ pruned_loss_scale = (
+ 0.0
+ if warmup < 1.0
+ else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+ )
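+ # Concretely, with model_warm_step=3000 (see get_params), the scale is
+ # 0.0 for roughly the first 3000 batches, 0.1 until about batch 6000,
+ # and 1.0 afterwards.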
+ loss = (
+ params.simple_loss_scale * simple_loss
+ + pruned_loss_scale * pruned_loss
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (
+ (feature_lens // params.subsampling_factor).sum().item()
+ )
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ datatang_train_dl: Optional[torch.utils.data.DataLoader],
+ valid_dl: torch.utils.data.DataLoader,
+ rng: random.Random,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler; we call its step_batch() method for every batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ aishell_tot_loss = MetricsTracker()
+ datatang_tot_loss = MetricsTracker()
+ tot_loss = MetricsTracker()
+
+ # index 0: for the aishell dataset
+ # index 1: for the aidatatang_200zh dataset
+ # These weights set the probability of drawing each batch from the
+ # corresponding dataloader.
+ dl_weights = [1 - params.datatang_prob, params.datatang_prob]
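+ # For example, --datatang-prob 0.2 gives dl_weights = [0.8, 0.2], so on
+ # average 80% of batches come from aishell and 20% from aidatatang_200zh.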
+
+ iter_aishell = iter(train_dl)
+ if datatang_train_dl is not None:
+ iter_datatang = iter(datatang_train_dl)
+
+ batch_idx = 0
+
+ while True:
+ if datatang_train_dl is not None:
+ idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
+ dl = iter_aishell if idx == 0 else iter_datatang
+ else:
+ dl = iter_aishell
+
+ try:
+ batch = next(dl)
+ except StopIteration:
+ break
+ batch_idx += 1
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ aishell = is_aishell(batch["supervisions"]["cut"][0])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ warmup=(params.batch_idx_train / params.model_warm_step),
+ )
+ # summary stats
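+ # Scaling the running totals by (1 - 1/reset_interval) before adding
+ # the stats of the current batch gives an exponentially decaying
+ # average over roughly the last reset_interval (=200) batches.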
+ if datatang_train_dl is not None:
+ tot_loss = (
+ tot_loss * (1 - 1 / params.reset_interval)
+ ) + loss_info
+
+ if aishell:
+ aishell_tot_loss = (
+ aishell_tot_loss * (1 - 1 / params.reset_interval)
+ ) + loss_info
+ prefix = "aishell" # for logging only
+ else:
+ datatang_tot_loss = (
+ datatang_tot_loss * (1 - 1 / params.reset_interval)
+ ) + loss_info
+ prefix = "datatang"
+
+ # NOTE: We use reduction=sum, so the loss is summed over all
+ # utterances in the batch; no normalization is applied so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(
+ batch, params=params, graph_compiler=graph_compiler
+ )
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ if datatang_train_dl is not None:
+ datatang_str = f"datatang_tot_loss[{datatang_tot_loss}], "
+ tot_loss_str = (
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ )
+ else:
+ tot_loss_str = ""
+ datatang_str = ""
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
+ f"{tot_loss_str}"
+ f"aishell_tot_loss[{aishell_tot_loss}], "
+ f"{datatang_str}"
+ f"batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}"
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer,
+ f"train/current_{prefix}_",
+ params.batch_idx_train,
+ )
+ if datatang_train_dl is not None:
+ # If datatang_train_dl is None, tot_loss is identical to
+ # aishell_tot_loss, so there is no need to write it separately.
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+ aishell_tot_loss.write_summary(
+ tb_writer, "train/aishell_tot_", params.batch_idx_train
+ )
+ if datatang_train_dl is not None:
+ datatang_tot_loss.write_summary(
+ tb_writer, "train/datatang_tot_", params.batch_idx_train
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 12 seconds
+ #
+ # Caution: There is a reason to select 12.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ return 1.0 <= c.duration <= 12.0
+
+ cuts = cuts.filter(remove_short_and_long_utt)
+
+ return cuts
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ rng = random.Random(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ oov="",
+ )
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2 ** 22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ aishell = AIShell(manifest_dir=args.manifest_dir)
+ train_cuts = aishell.train_cuts()
+ train_cuts = filter_short_and_long_utterances(train_cuts)
+
+ if args.enable_musan:
+ cuts_musan = load_manifest(
+ Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+ )
+ else:
+ cuts_musan = None
+
+ asr_datamodule = AsrDataModule(args)
+
+ train_dl = asr_datamodule.train_dataloaders(
+ train_cuts,
+ on_the_fly_feats=False,
+ cuts_musan=cuts_musan,
+ )
+
+ if params.datatang_prob > 0:
+ datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
+ train_datatang_cuts = datatang.train_cuts()
+ train_datatang_cuts = filter_short_and_long_utterances(
+ train_datatang_cuts
+ )
+ train_datatang_cuts = train_datatang_cuts.repeat(times=None)
+ datatang_train_dl = asr_datamodule.train_dataloaders(
+ train_datatang_cuts,
+ on_the_fly_feats=False,
+ cuts_musan=cuts_musan,
+ )
+ else:
+ datatang_train_dl = None
+ logging.info("Not using aidatatang_200zh for training")
+
+ valid_cuts = aishell.valid_cuts()
+ valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
+
+ for dl in [
+ train_dl,
+ # datatang_train_dl
+ ]:
+ if dl is not None:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ logging.info(f"start training from epoch {params.start_epoch}")
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+ if datatang_train_dl is not None:
+ datatang_train_dl.sampler.set_epoch(epoch)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ datatang_train_dl=datatang_train_dl,
+ valid_dl=valid_dl,
+ rng=rng,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = graph_compiler.texts_to_ids(supervisions["text"])
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+ # (i.e. are not remembered by the decaying-average in adam), because
+ # we want to avoid these params being subject to shrinkage in adam.
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ warmup=0.0 if params.start_epoch == 1 else 1.0,
+ )
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(
+ batch, params=params, graph_compiler=graph_compiler
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+
+ assert 0 <= args.datatang_prob < 1, args.datatang_prob
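+ # Note: --datatang-prob 0 disables aidatatang_200zh completely; in that
+ # case no extra decoder/joiner is created and no second dataloader is
+ # used (see get_transducer_model() and run()).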
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 65caa656e..d24ba6bb7 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -1,4 +1,5 @@
# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -16,16 +17,17 @@
import argparse
+import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import List
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
- BucketingSampler,
CutConcatenate,
CutMix,
+ DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
@@ -91,7 +93,7 @@ class AishellAsrDataModule:
"--num-buckets",
type=int,
default=30,
- help="The number of buckets for the BucketingSampler"
+ help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
@@ -131,6 +133,12 @@ class AishellAsrDataModule:
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
group.add_argument(
"--return-cuts",
type=str2bool,
@@ -176,7 +184,7 @@ class AishellAsrDataModule:
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
- self.args.manifest_dir / "cuts_musan.json.gz"
+ self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
transforms = []
@@ -210,10 +218,20 @@ class AishellAsrDataModule:
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
- num_frame_masks=2,
+ num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
@@ -250,14 +268,13 @@ class AishellAsrDataModule:
)
if self.args.bucketing_sampler:
- logging.info("Using BucketingSampler.")
- train_sampler = BucketingSampler(
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
- bucket_method="equal_duration",
- drop_last=True,
+ drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
@@ -301,7 +318,7 @@ class AishellAsrDataModule:
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
- valid_sampler = BucketingSampler(
+ valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
@@ -325,8 +342,10 @@ class AishellAsrDataModule:
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
- sampler = BucketingSampler(
- cuts, max_duration=self.args.max_duration, shuffle=False
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
)
test_dl = DataLoader(
test,
@@ -339,17 +358,21 @@ class AishellAsrDataModule:
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
- cuts_train = load_manifest(
- self.args.manifest_dir / "cuts_train.json.gz"
+ cuts_train = load_manifest_lazy(
+ self.args.manifest_dir / "aishell_cuts_train.jsonl.gz"
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
- return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz"
+ )
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
- return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aishell_cuts_test.jsonl.gz"
+ )
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
index a0045115d..7619b0551 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
@@ -15,6 +15,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Usage
+ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ ./tdnn_lstm_ctc/train.py \
+ --world-size 4 \
+ --num-epochs 20 \
+ --max-duration 300
+"""
import argparse
import logging
@@ -92,6 +100,13 @@ def get_parser():
""",
)
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
return parser
@@ -507,7 +522,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
- fix_random_seed(42)
+ fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@@ -557,6 +572,7 @@ def run(rank, world_size, args):
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
+ fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
if epoch > params.start_epoch:
diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py
index 81d7708f9..66eb3eb63 100644
--- a/egs/aishell/ASR/transducer_stateless/conformer.py
+++ b/egs/aishell/ASR/transducer_stateless/conformer.py
@@ -110,7 +110,7 @@ class Conformer(Transformer):
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# Caution: We assume the subsampling factor is 4!
- lengths = ((x_lens - 1) // 2 - 1) // 2
+ lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
@@ -362,7 +362,7 @@ class RelPositionalEncoding(torch.nn.Module):
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
- # Suppose `i` means to the position of query vecotr and `j` means the
+ # Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use positive relative positions when keys
- # are to the left (i>j) and negative relative positions otherwise (i<j).
- params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+ # params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+ params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py
index dca084477..c2c6552a9 100644
--- a/egs/aishell/ASR/transducer_stateless/decoder.py
+++ b/egs/aishell/ASR/transducer_stateless/decoder.py
@@ -82,17 +82,17 @@ class Decoder(nn.Module):
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
- embeding_out = self.embedding(y)
+ embedding_out = self.embedding(y)
if self.context_size > 1:
- embeding_out = embeding_out.permute(0, 2, 1)
+ embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
- embeding_out = F.pad(
- embeding_out, pad=(self.context_size - 1, 0)
+ embedding_out = F.pad(
+ embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
- assert embeding_out.size(-1) == self.context_size
- embeding_out = self.conv(embeding_out)
- embeding_out = embeding_out.permute(0, 2, 1)
- return embeding_out
+ assert embedding_out.size(-1) == self.context_size
+ embedding_out = self.conv(embedding_out)
+ embedding_out = embedding_out.permute(0, 2, 1)
+ return embedding_out
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index 641555bdb..4c6519b96 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+# 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -22,7 +23,7 @@
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --lang-dir data/lang_char \
--epoch 20 \
--avg 10
@@ -33,21 +34,21 @@ To use the generated file with `transducer_stateless/decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
- cd /path/to/egs/librispeech/ASR
+ cd /path/to/egs/aishell/ASR
./transducer_stateless/decode.py \
--exp-dir ./transducer_stateless/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1 \
- --bpe-model data/lang_bpe_500/bpe.model
+ --lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
-import sentencepiece as spm
import torch
+import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@@ -55,6 +56,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
@@ -90,10 +92,10 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--lang-dir",
type=str,
- default="data/lang_bpe_500/bpe.model",
- help="Path to the BPE model",
+ default="data/lang_char",
+ help="The lang dir",
)
parser.add_argument(
@@ -133,7 +135,7 @@ def get_params() -> AttributeDict:
return params
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
@@ -147,7 +149,7 @@ def get_encoder_model(params: AttributeDict):
return encoder
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
@@ -157,7 +159,7 @@ def get_decoder_model(params: AttributeDict):
return decoder
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
@@ -165,7 +167,7 @@ def get_joiner_model(params: AttributeDict):
return joiner
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
@@ -182,8 +184,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
- assert args.jit is False, "Support torchscript will be added later"
-
params = get_params()
params.update(vars(args))
@@ -193,12 +193,10 @@ def main():
logging.info(f"device: {device}")
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
+ lexicon = Lexicon(params.lang_dir)
- # <blk> is defined in local/train_bpe_model.py
- params.blank_id = sp.piece_to_id("<blk>")
- params.vocab_size = sp.get_piece_size()
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
@@ -225,6 +223,11 @@ def main():
model.eval()
if params.jit:
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+ # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py
index 2f0f9a183..994305fc1 100644
--- a/egs/aishell/ASR/transducer_stateless/model.py
+++ b/egs/aishell/ASR/transducer_stateless/model.py
@@ -14,15 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Note we use `rnnt_loss` from torchaudio, which exists only in
-torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
-"""
import k2
import torch
import torch.nn as nn
-import torchaudio
-import torchaudio.functional
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
@@ -108,18 +102,13 @@ class Transducer(nn.Module):
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
- assert hasattr(torchaudio.functional, "rnnt_loss"), (
- f"Current torchaudio version: {torchaudio.__version__}\n"
- "Please install a version >= 0.10.0"
+ y_padded = y_padded.to(torch.int64)
+ boundary = torch.zeros(
+ (x.size(0), 4), dtype=torch.int64, device=x.device
)
+ boundary[:, 2] = y_lens
+ boundary[:, 3] = x_lens
- loss = torchaudio.functional.rnnt_loss(
- logits=logits,
- targets=y_padded,
- logit_lengths=x_lens,
- target_lengths=y_lens,
- blank=blank_id,
- reduction="sum",
- )
+ loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary)
return loss
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index 65ac5f3ff..db89c4d67 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by
import argparse
import logging
import math
-from typing import List
from pathlib import Path
+from typing import List
import kaldifeat
import torch
+import torch.nn as nn
import torchaudio
from beam_search import beam_search, greedy_search
from conformer import Conformer
@@ -57,10 +58,10 @@ from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
-from icefall.env import get_env_info
-from icefall.utils import AttributeDict
-from icefall.lexicon import Lexicon
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict
def get_parser():
@@ -150,7 +151,7 @@ def get_params() -> AttributeDict:
return params
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
@@ -164,7 +165,7 @@ def get_encoder_model(params: AttributeDict):
return encoder
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
@@ -174,7 +175,7 @@ def get_decoder_model(params: AttributeDict):
return decoder
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
@@ -182,7 +183,7 @@ def get_joiner_model(params: AttributeDict):
return joiner
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index 7da8e28a1..d54157709 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -21,6 +21,7 @@
import argparse
import logging
+import warnings
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
@@ -129,6 +130,13 @@ def get_parser():
"2 means tri-gram",
)
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
return parser
@@ -204,7 +212,7 @@ def get_params() -> AttributeDict:
return params
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
@@ -219,7 +227,7 @@ def get_encoder_model(params: AttributeDict):
return encoder
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
@@ -229,7 +237,7 @@ def get_decoder_model(params: AttributeDict):
return decoder
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
@@ -237,7 +245,7 @@ def get_joiner_model(params: AttributeDict):
return joiner
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
@@ -379,7 +387,11 @@ def compute_loss(
assert loss.requires_grad == is_training
info = MetricsTracker()
- info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (
+ (feature_lens // params.subsampling_factor).sum().item()
+ )
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -534,7 +546,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
- fix_random_seed(42)
+ fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
@@ -558,7 +570,7 @@ def run(rank, world_size, args):
oov="",
)
- params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+ params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
@@ -592,25 +604,23 @@ def run(rank, world_size, args):
train_cuts = aishell.train_cuts()
def remove_short_and_long_utt(c: Cut):
- # Keep only utterances with duration between 1 second and 20 seconds
- return 1.0 <= c.duration <= 20.0
-
- num_in_total = len(train_cuts)
+ # Keep only utterances with duration between 1 second and 12 seconds
+ #
+ # Caution: There is a reason to select 12.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ return 1.0 <= c.duration <= 12.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
- num_left = len(train_cuts)
- num_removed = num_in_total - num_left
- removed_percent = num_removed / num_in_total * 100
-
- logging.info(f"Before removing short and long utterances: {num_in_total}")
- logging.info(f"After removing short and long utterances: {num_left}")
- logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
-
train_dl = aishell.train_dataloaders(train_cuts)
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
+ fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/README.md b/egs/aishell/ASR/transducer_stateless_modified-2/README.md
new file mode 100644
index 000000000..b3c539670
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/README.md
@@ -0,0 +1,59 @@
+## Introduction
+
+The decoder, i.e., the prediction network, is from
+https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
+(Rnn-Transducer with Stateless Prediction Network)
+
+Different from `../transducer_stateless_modified`, this folder
+uses extra data, i.e., http://www.openslr.org/62/, during training.
+
+You can use the following command to start the training:
+
+```bash
+cd egs/aishell/ASR
+./prepare.sh --stop-stage 6
+./prepare_aidatatang_200zh.sh
+
+export CUDA_VISIBLE_DEVICES="0,1,2"
+
+./transducer_stateless_modified-2/train.py \
+ --world-size 3 \
+ --num-epochs 90 \
+ --start-epoch 0 \
+ --exp-dir transducer_stateless_modified-2/exp-2 \
+ --max-duration 250 \
+ --lr-factor 2.0 \
+ --context-size 2 \
+ --modified-transducer-prob 0.25 \
+ --datatang-prob 0.2
+```
+
+To decode, you can use
+
+```bash
+for epoch in 89; do
+ for avg in 30 38; do
+ ./transducer_stateless_modified-2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir transducer_stateless_modified-2/exp-2 \
+ --max-duration 100 \
+ --context-size 2 \
+ --decoding-method greedy_search \
+ --max-sym-per-frame 1
+ done
+done
+
+for epoch in 89; do
+ for avg in 38; do
+ ./transducer_stateless_modified-2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir transducer_stateless_modified-2/exp-2 \
+ --max-duration 100 \
+ --context-size 2 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+ done
+done
+```
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/__init__.py b/egs/aishell/ASR/transducer_stateless_modified-2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/aidatatang_200zh.py b/egs/aishell/ASR/transducer_stateless_modified-2/aidatatang_200zh.py
new file mode 100644
index 000000000..26d4ee111
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/aidatatang_200zh.py
@@ -0,0 +1,53 @@
+# Copyright 2021 Piotr Żelasko
+# 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+
+
+class AIDatatang200zh:
+ def __init__(self, manifest_dir: str):
+ """
+ Args:
+ manifest_dir:
+ It is expected to contain the following files::
+
+ - aidatatang_cuts_dev.jsonl.gz
+ - aidatatang_cuts_train.jsonl.gz
+ - aidatatang_cuts_test.jsonl.gz
+ """
+ self.manifest_dir = Path(manifest_dir)
+
+ def train_cuts(self) -> CutSet:
+ f = self.manifest_dir / "aidatatang_cuts_train.jsonl.gz"
+ logging.info(f"About to get train cuts from {f}")
+ cuts_train = load_manifest_lazy(f)
+ return cuts_train
+
+ def valid_cuts(self) -> CutSet:
+ f = self.manifest_dir / "aidatatang_cuts_valid.jsonl.gz"
+ logging.info(f"About to get valid cuts from {f}")
+ cuts_valid = load_manifest_lazy(f)
+ return cuts_valid
+
+ def test_cuts(self) -> CutSet:
+ f = self.manifest_dir / "aidatatang_cuts_test.jsonl.gz"
+ logging.info(f"About to get test cuts from {f}")
+ cuts_test = load_manifest_lazy(f)
+ return cuts_test
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/aishell.py b/egs/aishell/ASR/transducer_stateless_modified-2/aishell.py
new file mode 100644
index 000000000..ddeca4d88
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/aishell.py
@@ -0,0 +1,53 @@
+# Copyright 2021 Piotr Żelasko
+# 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+
+
+class AIShell:
+ def __init__(self, manifest_dir: str):
+ """
+ Args:
+ manifest_dir:
+ It is expected to contain the following files::
+
+ - aishell_cuts_dev.jsonl.gz
+ - aishell_cuts_train.jsonl.gz
+ - aishell_cuts_test.jsonl.gz
+ """
+ self.manifest_dir = Path(manifest_dir)
+
+ def train_cuts(self) -> CutSet:
+ f = self.manifest_dir / "aishell_cuts_train.jsonl.gz"
+ logging.info(f"About to get train cuts from {f}")
+ cuts_train = load_manifest_lazy(f)
+ return cuts_train
+
+ def valid_cuts(self) -> CutSet:
+ f = self.manifest_dir / "aishell_cuts_dev.jsonl.gz"
+ logging.info(f"About to get valid cuts from {f}")
+ cuts_valid = load_manifest_lazy(f)
+ return cuts_valid
+
+ def test_cuts(self) -> CutSet:
+ f = self.manifest_dir / "aishell_cuts_test.jsonl.gz"
+ logging.info(f"About to get test cuts from {f}")
+ cuts_test = load_manifest_lazy(f)
+ return cuts_test
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
new file mode 100644
index 000000000..838e53658
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
@@ -0,0 +1,301 @@
+# Copyright 2021 Piotr Żelasko
+# 2022 Xiaomi Corp. (authors: Fangjun Kuang
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import inspect
+import logging
+from pathlib import Path
+from typing import Optional
+
+from lhotse import CutSet, Fbank, FbankConfig
+from lhotse.dataset import (
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import (
+ OnTheFlyFeatures,
+ PrecomputedFeatures,
+)
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class AsrDataModule:
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler "
+ "(you might want to increase it for larger datasets).",
+ )
+
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available. Used only in dev/test CutSet",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ on_the_fly_feats: bool,
+ cuts_musan: Optional[CutSet] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ Cuts for training.
+ cuts_musan:
+ If not None, it is the cuts for mixing.
+ on_the_fly_feats:
+ True to use OnTheFlyFeatures;
+ False to use PrecomputedFeatures.
+ """
+ transforms = []
+ if cuts_musan is not None:
+ logging.info("Enable MUSAN")
+ transforms.append(
+ CutMix(
+ cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+ )
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ input_transforms = []
+
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(
+ f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+ )
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=(
+ OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if on_the_fly_feats
+ else PrecomputedFeatures()
+ ),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=True,
+ )
+
+ logging.info("About to create train dataloader")
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ )
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/beam_search.py b/egs/aishell/ASR/transducer_stateless_modified-2/beam_search.py
new file mode 120000
index 000000000..e188617a8
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/beam_search.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/conformer.py b/egs/aishell/ASR/transducer_stateless_modified-2/conformer.py
new file mode 120000
index 000000000..88975988f
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/conformer.py
@@ -0,0 +1 @@
+../transducer_stateless_modified/conformer.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
new file mode 100755
index 000000000..47265f846
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
@@ -0,0 +1,526 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./transducer_stateless_modified-2/decode.py \
+ --epoch 89 \
+ --avg 38 \
+ --exp-dir ./transducer_stateless_modified-2/exp \
+ --max-duration 100 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./transducer_stateless_modified-2/decode.py \
+ --epoch 89 \
+ --avg 38 \
+ --exp-dir ./transducer_stateless_modified-2/exp \
+ --max-duration 100 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./transducer_stateless_modified-2/decode.py \
+ --epoch 89 \
+ --avg 38 \
+ --exp-dir ./transducer_stateless_modified-2/exp \
+ --max-duration 100 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+(4) fast beam search
+./transducer_stateless_modified-2/decode.py \
+ --epoch 89 \
+ --avg 38 \
+ --exp-dir ./transducer_stateless_modified-2/exp \
+ --max-duration 100 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+"""
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from aishell import AIShell
+from asr_datamodule import AsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=10,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="transducer_stateless_modified-2/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+        Used only when --decoding-method is greedy_search""",
+ )
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ token_table: k2.SymbolTable,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+           if greedy_search is used, it would be "greedy_search".
+           If beam search with a beam size of 7 is used, it would be
+           "beam_7".
+    - value: It contains the decoding result. `len(value)` equals the
+             batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ token_table:
+ It maps token ID to a string.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=feature, x_lens=feature_lens
+ )
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ else:
+ hyp_tokens = []
+ batch_size = encoder_out.size(0)
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyp_tokens.append(hyp)
+
+ hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ token_table: k2.SymbolTable,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ token_table:
+ It maps a token ID to a string.
+ decoding_graph:
+        The decoding graph. It can be either a `k2.trivial_graph` or an HLG.
+        Used only when --decoding-method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 10
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ token_table=token_table,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+    # We compute CER for the AIShell dataset.
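+    # Illustration only (hypothetical transcript): a word list such as
+    # ["你好", "世界"] is joined and re-split into characters
+    # ["你", "好", "世", "界"] before computing the error statistics.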
+ results_char = []
+ for res in results:
+ results_char.append((list("".join(res[0])), list("".join(res[1]))))
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results_char, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir
+ / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tCER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += (
+ f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ )
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ asr_datamodule = AsrDataModule(args)
+ aishell = AIShell(manifest_dir=args.manifest_dir)
+ test_cuts = aishell.test_cuts()
+ test_dl = asr_datamodule.test_dataloaders(test_cuts)
+
+ test_sets = ["test"]
+ test_dls = [test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ token_table=lexicon.token_table,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decoder.py b/egs/aishell/ASR/transducer_stateless_modified-2/decoder.py
new file mode 120000
index 000000000..bdfcea5c2
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decoder.py
@@ -0,0 +1 @@
+../transducer_stateless_modified/decoder.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/encoder_interface.py b/egs/aishell/ASR/transducer_stateless_modified-2/encoder_interface.py
new file mode 120000
index 000000000..a2a5f22cf
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/encoder_interface.py
@@ -0,0 +1 @@
+../transducer_stateless_modified/encoder_interface.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
new file mode 100755
index 000000000..3bd2ceb11
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
@@ -0,0 +1,249 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./transducer_stateless_modified-2/export.py \
+ --exp-dir ./transducer_stateless_modified-2/exp \
+ --epoch 89 \
+ --avg 38
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `transducer_stateless_modified-2/decode.py`,
+you can do::
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/aishell/ASR
+ ./transducer_stateless_modified-2/decode.py \
+ --exp-dir ./transducer_stateless_modified-2/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 100 \
+ --lang-dir data/lang_char
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from model import Transducer
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=20,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=10,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=Path,
+ default=Path("transducer_stateless_modified-2/exp"),
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default=Path("data/lang_char"),
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ # parameters for conformer
+ "feature_dim": 80,
+ "encoder_out_dim": 512,
+ "subsampling_factor": 4,
+ "attention_dim": 512,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "num_encoder_layers": 12,
+ "vgg_frontend": False,
+ "env_info": get_env_info(),
+ }
+ )
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ output_dim=params.encoder_out_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.attention_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ vgg_frontend=params.vgg_frontend,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ embedding_dim=params.encoder_out_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ input_dim=params.encoder_out_dim,
+ output_dim=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ )
+ return model
+
+
+def main():
+ args = get_parser().parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ model.to(device)
+
+ if params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit:
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torch.jit.script")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/joiner.py b/egs/aishell/ASR/transducer_stateless_modified-2/joiner.py
new file mode 120000
index 000000000..e9e435ecd
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/joiner.py
@@ -0,0 +1 @@
+../transducer_stateless_modified/joiner.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/model.py b/egs/aishell/ASR/transducer_stateless_modified-2/model.py
new file mode 100644
index 000000000..086957d0b
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/model.py
@@ -0,0 +1,163 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from typing import Optional
+
+import k2
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+
+from icefall.utils import add_sos
+
+
+class Transducer(nn.Module):
+ """It implements https://arxiv.org/pdf/1211.3711.pdf
+ "Sequence Transduction with Recurrent Neural Networks"
+ """
+
+ def __init__(
+ self,
+ encoder: EncoderInterface,
+ decoder: nn.Module,
+ joiner: nn.Module,
+ decoder_datatang: Optional[nn.Module] = None,
+ joiner_datatang: Optional[nn.Module] = None,
+ ):
+ """
+ Args:
+ encoder:
+            It is the transcription network in the paper. It accepts
+            two inputs: `x` of shape (N, T, C) and `x_lens` of shape (N,).
+ It returns two tensors: `logits` of shape (N, T, C) and
+ `logit_lens` of shape (N,).
+ decoder:
+ It is the prediction network in the paper. Its input shape
+ is (N, U) and its output shape is (N, U, C). It should contain
+ one attribute: `blank_id`.
+ joiner:
+ It has two inputs with shapes: (N, T, C) and (N, U, C). Its
+ output shape is (N, T, U, C). Note that its output contains
+ unnormalized probs, i.e., not processed by log-softmax.
+ decoder_datatang:
+ The decoder for the aidatatang_200zh dataset.
+ joiner_datatang:
+ The joiner for the aidatatang_200zh dataset.
+ """
+ super().__init__()
+ assert isinstance(encoder, EncoderInterface), type(encoder)
+ assert hasattr(decoder, "blank_id")
+ if decoder_datatang is not None:
+ assert hasattr(decoder_datatang, "blank_id")
+
+ self.encoder = encoder
+ self.decoder = decoder
+ self.joiner = joiner
+
+ self.decoder_datatang = decoder_datatang
+ self.joiner_datatang = joiner_datatang
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: k2.RaggedTensor,
+ aishell: bool = True,
+ modified_transducer_prob: float = 0.0,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C).
+ x_lens:
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
+ before padding.
+ y:
+ A ragged tensor with 2 axes [utt][label]. It contains labels of each
+ utterance.
+          aishell:
+            True if the batch is from the AIShell dataset; False if it is
+            from the aidatatang_200zh dataset. It decides which decoder and
+            joiner to use.
+          modified_transducer_prob:
+            The probability to use modified transducer loss.
+ Returns:
+ Return the transducer loss.
+ """
+ assert x.ndim == 3, x.shape
+ assert x_lens.ndim == 1, x_lens.shape
+ assert y.num_axes == 2, y.num_axes
+
+ assert x.size(0) == x_lens.size(0) == y.dim0
+
+ encoder_out, x_lens = self.encoder(x, x_lens)
+ assert torch.all(x_lens > 0)
+
+ # Now for the decoder, i.e., the prediction network
+ row_splits = y.shape.row_splits(1)
+ y_lens = row_splits[1:] - row_splits[:-1]
+
+ blank_id = self.decoder.blank_id
+ sos_y = add_sos(y, sos_id=blank_id)
+
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+ sos_y_padded = sos_y_padded.to(torch.int64)
+
+ if aishell:
+ decoder = self.decoder
+ joiner = self.joiner
+ else:
+ decoder = self.decoder_datatang
+ joiner = self.joiner_datatang
+
+ decoder_out = decoder(sos_y_padded)
+
+ # +1 here since a blank is prepended to each utterance.
+ logits = joiner(
+ encoder_out=encoder_out,
+ decoder_out=decoder_out,
+ encoder_out_len=x_lens,
+ decoder_out_len=y_lens + 1,
+ )
+
+ # rnnt_loss requires 0 padded targets
+ # Note: y does not start with SOS
+ y_padded = y.pad(mode="constant", padding_value=0)
+
+ # We don't put this `import` at the beginning of the file
+        # as it is required only during training, not during the
+        # inference stage.
+ import optimized_transducer
+
+ assert 0 <= modified_transducer_prob <= 1
+
+ if modified_transducer_prob == 0:
+ one_sym_per_frame = False
+ elif random.random() < modified_transducer_prob:
+ # random.random() returns a float in the range [0, 1)
+ one_sym_per_frame = True
+ else:
+ one_sym_per_frame = False
+
+ loss = optimized_transducer.transducer_loss(
+ logits=logits,
+ targets=y_padded,
+ logit_lengths=x_lens,
+ target_lengths=y_lens,
+ blank=blank_id,
+ reduction="sum",
+ one_sym_per_frame=one_sym_per_frame,
+ from_log_softmax=False,
+ )
+
+ return loss
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
new file mode 100755
index 000000000..a95a4bc52
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
@@ -0,0 +1,335 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+
+(1) greedy search
+./transducer_stateless_modified-2/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) beam search
+./transducer_stateless_modified-2/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) modified beam search
+./transducer_stateless_modified-2/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(4) fast beam search
+./transducer_stateless_modified-2/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method fast_beam_search \
+    --beam 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
+
+from icefall.lexicon import Lexicon
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default=Path("data/lang_char"),
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="Maximum number of symbols per frame. "
+ "Use only when --method is greedy_search",
+ )
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert sample_rate == expected_sample_rate, (
+ f"expected sample rate: {expected_sample_rate}. "
+ f"Given: {sample_rate}"
+ )
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lens = [f.size(0) for f in features]
+ feature_lens = torch.tensor(feature_lens, device=device)
+
+ features = pad_sequence(
+ features, batch_first=True, padding_value=math.log(1e-10)
+ )
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=features, x_lens=feature_lens
+ )
+
+ num_waves = encoder_out.size(0)
+ hyp_list = []
+ logging.info(f"Using {params.method}")
+
+ if params.method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ hyp_list = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_list = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ elif params.method == "modified_beam_search":
+ hyp_list = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ else:
+ for i in range(num_waves):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.method}"
+ )
+ hyp_list.append(hyp)
+
+ hyps = []
+ for hyp in hyp_list:
+ hyps.append([lexicon.token_table[i] for i in hyp])
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/subsampling.py b/egs/aishell/ASR/transducer_stateless_modified-2/subsampling.py
new file mode 120000
index 000000000..6fee09e58
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/subsampling.py
@@ -0,0 +1 @@
+../conformer_ctc/subsampling.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/test_decoder.py b/egs/aishell/ASR/transducer_stateless_modified-2/test_decoder.py
new file mode 120000
index 000000000..fbe1679ea
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/test_decoder.py
@@ -0,0 +1 @@
+../transducer_stateless_modified/test_decoder.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
new file mode 100755
index 000000000..225d0d709
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py
@@ -0,0 +1,876 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang
+# Mingshuang Luo)
+# Copyright 2021 (Pingfeng Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+./prepare.sh
+./prepare_aidatatang_200zh.sh
+
+export CUDA_VISIBLE_DEVICES="0,1,2"
+
+./transducer_stateless_modified-2/train.py \
+ --world-size 3 \
+ --num-epochs 90 \
+ --start-epoch 0 \
+ --exp-dir transducer_stateless_modified-2/exp-2 \
+ --max-duration 250 \
+ --lr-factor 2.0 \
+ --context-size 2 \
+ --modified-transducer-prob 0.25 \
+ --datatang-prob 0.2
+"""
+
+
+import argparse
+import logging
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Optional, Tuple
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from aidatatang_200zh import AIDatatang200zh
+from aishell import AIShell
+from asr_datamodule import AsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse import CutSet, load_manifest
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
+
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ transducer_stateless/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="transducer_stateless_modified-2/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--lr-factor",
+ type=float,
+ default=5.0,
+ help="The lr_factor for Noam optimizer",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--modified-transducer-prob",
+ type=float,
+ default=0.25,
+ help="""The probability to use modified transducer loss.
+        In modified transducer, it limits the maximum number of symbols
+ per frame to 1. See also the option --max-sym-per-frame in
+ transducer_stateless/decode.py
+ """,
+ )
+
+ parser.add_argument(
+ "--datatang-prob",
+ type=float,
+ default=0.2,
+ help="The probability to select a batch from the "
+ "aidatatang_200zh dataset",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+                       epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - attention_dim: Hidden dim for multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warm_step for Noam optimizer.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 800, # For the 100h subset, use 800
+ # parameters for conformer
+ "feature_dim": 80,
+ "encoder_out_dim": 512,
+ "subsampling_factor": 4,
+ "attention_dim": 512,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "num_encoder_layers": 12,
+ "vgg_frontend": False,
+ # parameters for Noam
+ "warm_step": 80000, # For the 100h subset, use 8k
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Conformer and Transformer
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ output_dim=params.encoder_out_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.attention_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ vgg_frontend=params.vgg_frontend,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ embedding_dim=params.encoder_out_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ input_dim=params.encoder_out_dim,
+ output_dim=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ decoder_datatang = get_decoder_model(params)
+ joiner_datatang = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ decoder_datatang=decoder_datatang,
+ joiner_datatang=joiner_datatang,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> None:
+ """Load checkpoint from file.
+
+ If params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+ Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+ it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The learning rate scheduler we are using.
+ Returns:
+ Return None.
+ """
+ if params.start_epoch <= 0:
+ return
+
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def is_aishell(c: Cut) -> bool:
+ """Return True if this cut is from the AIShell dataset.
+
+ Note:
+ During data preparation, we set the custom field in
+ the supervision segment of aidatatang_200zh to
+ dict(origin='aidatatang_200zh')
+ See ../local/process_aidatatang_200zh.py.
+ """
+ return c.supervisions[0].custom is None
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute RNN-T loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ aishell = is_aishell(supervisions["cut"][0])
+
+ texts = batch["supervisions"]["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ aishell=aishell,
+ modified_transducer_prob=params.modified_transducer_prob,
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (
+ (feature_lens // params.subsampling_factor).sum().item()
+ )
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ datatang_train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ rng: random.Random,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ train_dl:
+ Dataloader for the training dataset.
+ datatang_train_dl:
+ Dataloader for the aidatatang_200zh training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ """
+ model.train()
+
+ aishell_tot_loss = MetricsTracker()
+ datatang_tot_loss = MetricsTracker()
+ tot_loss = MetricsTracker()
+
+    # index 0: for aishell
+    # index 1: for aidatatang_200zh
+    # This sets the probabilities for choosing which dataset each batch
+    # is drawn from.
+ dl_weights = [1 - params.datatang_prob, params.datatang_prob]
+
+ iter_aishell = iter(train_dl)
+ iter_datatang = iter(datatang_train_dl)
+
+ batch_idx = 0
+
+ while True:
+ idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
+ dl = iter_aishell if idx == 0 else iter_datatang
+
+ try:
+ batch = next(dl)
+ except StopIteration:
+ break
+ batch_idx += 1
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ aishell = is_aishell(batch["supervisions"]["cut"][0])
+
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+ if aishell:
+ aishell_tot_loss = (
+ aishell_tot_loss * (1 - 1 / params.reset_interval)
+ ) + loss_info
+ prefix = "aishell" # for logging only
+ else:
+ datatang_tot_loss = (
+ datatang_tot_loss * (1 - 1 / params.reset_interval)
+ ) + loss_info
+ prefix = "datatang"
+
+        # NOTE: We use reduction=sum, so the loss is summed over the
+        # utterances in the batch; no normalization is applied to it so far.
+
+ optimizer.zero_grad()
+ loss.backward()
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
+ optimizer.step()
+
+ if batch_idx % params.log_interval == 0:
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"aishell_tot_loss[{aishell_tot_loss}], "
+ f"datatang_tot_loss[{datatang_tot_loss}], "
+ f"batch size: {batch_size}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ if tb_writer is not None:
+ loss_info.write_summary(
+ tb_writer,
+ f"train/current_{prefix}_",
+ params.batch_idx_train,
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+ aishell_tot_loss.write_summary(
+ tb_writer, "train/aishell_tot_", params.batch_idx_train
+ )
+ datatang_tot_loss.write_summary(
+ tb_writer, "train/datatang_tot_", params.batch_idx_train
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 12 seconds
+ #
+ # Caution: There is a reason to select 12.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ return 1.0 <= c.duration <= 12.0
+
+    cuts = cuts.filter(remove_short_and_long_utt)
+
+    return cuts
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ seed = 42
+ fix_random_seed(seed)
+ rng = random.Random(seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ oov="",
+ )
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+ model.device = device
+
+ optimizer = Noam(
+ model.parameters(),
+ model_size=params.attention_dim,
+ factor=params.lr_factor,
+ warm_step=params.warm_step,
+ )
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ aishell = AIShell(manifest_dir=args.manifest_dir)
+
+ train_cuts = aishell.train_cuts()
+ train_cuts = filter_short_and_long_utterances(train_cuts)
+
+ datatang = AIDatatang200zh(manifest_dir=args.manifest_dir)
+ train_datatang_cuts = datatang.train_cuts()
+ train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts)
+ train_datatang_cuts = train_datatang_cuts.repeat(times=None)
+
+ if args.enable_musan:
+ cuts_musan = load_manifest(
+ Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+ )
+ else:
+ cuts_musan = None
+
+ asr_datamodule = AsrDataModule(args)
+
+ train_dl = asr_datamodule.train_dataloaders(
+ train_cuts,
+ on_the_fly_feats=False,
+ cuts_musan=cuts_musan,
+ )
+
+ datatang_train_dl = asr_datamodule.train_dataloaders(
+ train_datatang_cuts,
+ on_the_fly_feats=False,
+ cuts_musan=cuts_musan,
+ )
+
+ valid_cuts = aishell.valid_cuts()
+ valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
+
+ for dl in [
+ train_dl,
+ # datatang_train_dl
+ ]:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ for epoch in range(params.start_epoch, params.num_epochs):
+ train_dl.sampler.set_epoch(epoch)
+ datatang_train_dl.sampler.set_epoch(epoch)
+
+ cur_lr = optimizer._rate
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ if rank == 0:
+ logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ datatang_train_dl=datatang_train_dl,
+ valid_dl=valid_dl,
+ rng=rng,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ )
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+ model: nn.Module,
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ optimizer.zero_grad()
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
+ optimizer.step()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+
+ assert 0 <= args.datatang_prob < 1, args.datatang_prob
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/transformer.py b/egs/aishell/ASR/transducer_stateless_modified-2/transformer.py
new file mode 120000
index 000000000..4320d1105
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/transformer.py
@@ -0,0 +1 @@
+../transducer_stateless_modified/transformer.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified/README.md b/egs/aishell/ASR/transducer_stateless_modified/README.md
new file mode 100644
index 000000000..9709eb9a0
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/README.md
@@ -0,0 +1,21 @@
+## Introduction
+
+The decoder, i.e., the prediction network, is from
+https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
+(Rnn-Transducer with Stateless Prediction Network)
+
+You can use the following command to start the training:
+
+```bash
+cd egs/aishell/ASR
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./transducer_stateless_modified/train.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --exp-dir transducer_stateless_modified/exp \
+ --max-duration 250 \
+ --lr-factor 2.5
+```
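+
+After training, you can run decoding with the matching `decode.py`. The
+command below is only a sketch (the flags are those defined in
+`transducer_stateless_modified/decode.py`; adjust `--epoch` and `--avg`
+to the checkpoints you actually have):
+
+```bash
+cd egs/aishell/ASR
+
+./transducer_stateless_modified/decode.py \
+    --epoch 14 \
+    --avg 7 \
+    --exp-dir transducer_stateless_modified/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
+```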
diff --git a/egs/aishell/ASR/transducer_stateless_modified/__init__.py b/egs/aishell/ASR/transducer_stateless_modified/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/aishell/ASR/transducer_stateless_modified/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified/asr_datamodule.py
new file mode 120000
index 000000000..a73848de9
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/asr_datamodule.py
@@ -0,0 +1 @@
+../conformer_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified/beam_search.py b/egs/aishell/ASR/transducer_stateless_modified/beam_search.py
new file mode 120000
index 000000000..e188617a8
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/beam_search.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified/conformer.py b/egs/aishell/ASR/transducer_stateless_modified/conformer.py
new file mode 120000
index 000000000..8be0dc864
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/conformer.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py
new file mode 100755
index 000000000..4773ebc7d
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py
@@ -0,0 +1,527 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./transducer_stateless_modified/decode.py \
+ --epoch 14 \
+ --avg 7 \
+ --exp-dir ./transducer_stateless_modified/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./transducer_stateless_modified/decode.py \
+ --epoch 14 \
+ --avg 7 \
+ --exp-dir ./transducer_stateless_modified/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./transducer_stateless_modified/decode.py \
+ --epoch 14 \
+ --avg 7 \
+ --exp-dir ./transducer_stateless_modified/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search
+./transducer_stateless_modified/decode.py \
+ --epoch 14 \
+ --avg 7 \
+ --exp-dir ./transducer_stateless_modified/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=10,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="transducer_stateless_modified/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ token_table: k2.SymbolTable,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+    - value: It contains the decoding result. `len(value)` equals the
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
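+
+    For example, with greedy_search and a batch of two utterances, the
+    returned dict might look like the following (the token strings are
+    illustrative only)::
+
+        {"greedy_search": [["你", "好"], ["早", "上", "好"]]}
+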
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ token_table:
+ It maps token ID to a string.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+        only when --decoding_method is fast_beam_search.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=feature, x_lens=feature_lens
+ )
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ else:
+ hyp_tokens = []
+ batch_size = encoder_out.size(0)
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyp_tokens.append(hyp)
+
+ hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ token_table: k2.SymbolTable,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ token_table:
+ It maps a token ID to a string.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+        only when --decoding_method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 10
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ token_table=token_table,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+        # We compute CER for the aishell dataset.
+ results_char = []
+ for res in results:
+ results_char.append((list("".join(res[0])), list("".join(res[1]))))
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results_char, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir
+ / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tCER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += (
+ f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ )
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ aishell = AishellAsrDataModule(args)
+ test_cuts = aishell.test_cuts()
+ test_dl = aishell.test_dataloaders(test_cuts)
+
+ test_sets = ["test"]
+ test_dls = [test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ token_table=lexicon.token_table,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decoder.py b/egs/aishell/ASR/transducer_stateless_modified/decoder.py
new file mode 120000
index 000000000..82337f7ef
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/decoder.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified/encoder_interface.py b/egs/aishell/ASR/transducer_stateless_modified/encoder_interface.py
new file mode 120000
index 000000000..653c5b09a
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
new file mode 100755
index 000000000..11335a834
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/export.py
@@ -0,0 +1,249 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./transducer_stateless_modified/export.py \
+ --exp-dir ./transducer_stateless_modified/exp \
+ --epoch 64 \
+ --avg 33
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `transducer_stateless_modified/decode.py`,
+you can do::
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/aishell/ASR
+ ./transducer_stateless_modified/decode.py \
+ --exp-dir ./transducer_stateless_modified/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 100 \
+ --lang-dir data/lang_char
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from model import Transducer
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=20,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=10,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=Path,
+ default=Path("transducer_stateless_modified/exp"),
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default=Path("data/lang_char"),
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ # parameters for conformer
+ "feature_dim": 80,
+ "encoder_out_dim": 512,
+ "subsampling_factor": 4,
+ "attention_dim": 512,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "num_encoder_layers": 12,
+ "vgg_frontend": False,
+ "env_info": get_env_info(),
+ }
+ )
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ output_dim=params.encoder_out_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.attention_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ vgg_frontend=params.vgg_frontend,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ embedding_dim=params.encoder_out_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ input_dim=params.encoder_out_dim,
+ output_dim=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ )
+ return model
+
+
+def main():
+ args = get_parser().parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ model.to(device)
+
+ if params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit:
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torch.jit.script")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/joiner.py b/egs/aishell/ASR/transducer_stateless_modified/joiner.py
new file mode 120000
index 000000000..1aec6bfaf
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/joiner.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified/model.py b/egs/aishell/ASR/transducer_stateless_modified/model.py
new file mode 120000
index 000000000..16ddd93f0
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/model.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
new file mode 100755
index 000000000..262e822c2
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
@@ -0,0 +1,335 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+
+(1) greedy search
+./transducer_stateless_modified/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) beam search
+./transducer_stateless_modified/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) modified beam search
+./transducer_stateless_modified/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(4) fast beam search
+./transducer_stateless_modified/pretrained.py \
+ --checkpoint /path/to/pretrained.pt \
+ --lang-dir /path/to/lang_char \
+ --method fast_beam_search \
+  --beam 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
+
+from icefall.lexicon import Lexicon
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default=Path("data/lang_char"),
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="Maximum number of symbols per frame. "
+ "Use only when --method is greedy_search",
+ )
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert sample_rate == expected_sample_rate, (
+ f"expected sample rate: {expected_sample_rate}. "
+ f"Given: {sample_rate}"
+ )
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"])
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lens = [f.size(0) for f in features]
+ feature_lens = torch.tensor(feature_lens, device=device)
+
+ features = pad_sequence(
+ features, batch_first=True, padding_value=math.log(1e-10)
+ )
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=features, x_lens=feature_lens
+ )
+
+ num_waves = encoder_out.size(0)
+ hyp_list = []
+ logging.info(f"Using {params.method}")
+
+ if params.method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ hyp_list = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_list = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ elif params.method == "modified_beam_search":
+ hyp_list = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ else:
+ for i in range(num_waves):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.method}"
+ )
+ hyp_list.append(hyp)
+
+ hyps = []
+ for hyp in hyp_list:
+ hyps.append([lexicon.token_table[i] for i in hyp])
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/subsampling.py b/egs/aishell/ASR/transducer_stateless_modified/subsampling.py
new file mode 120000
index 000000000..6fee09e58
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/subsampling.py
@@ -0,0 +1 @@
+../conformer_ctc/subsampling.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/transducer_stateless_modified/test_decoder.py b/egs/aishell/ASR/transducer_stateless_modified/test_decoder.py
new file mode 100755
index 000000000..fe0bdee70
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/test_decoder.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+To run this file, do:
+
+ cd icefall/egs/aishell/ASR
+    python ./transducer_stateless_modified/test_decoder.py
+"""
+
+import torch
+from decoder import Decoder
+
+
+def test_decoder():
+ vocab_size = 3
+ blank_id = 0
+ embedding_dim = 128
+ context_size = 4
+
+ decoder = Decoder(
+ vocab_size=vocab_size,
+ embedding_dim=embedding_dim,
+ blank_id=blank_id,
+ context_size=context_size,
+ )
+ N = 100
+ U = 20
+ x = torch.randint(low=0, high=vocab_size, size=(N, U))
+ y = decoder(x)
+ assert y.shape == (N, U, embedding_dim)
+
+ # for inference
+ x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
+ y = decoder(x, need_pad=False)
+ assert y.shape == (N, 1, embedding_dim)
+
+
+def main():
+ test_decoder()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py
new file mode 100755
index 000000000..d3ffccafa
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/train.py
@@ -0,0 +1,753 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang
+# Mingshuang Luo)
+# Copyright 2021 (Pingfeng Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2"
+
+./transducer_stateless_modified/train.py \
+ --world-size 3 \
+ --num-epochs 65 \
+ --start-epoch 0 \
+ --exp-dir transducer_stateless_modified/exp \
+ --max-duration 250 \
+ --lr-factor 2.0 \
+ --context-size 2 \
+ --modified-transducer-prob 0.25
+"""
+
+
+import argparse
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Optional, Tuple
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
+
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ transducer_stateless/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="transducer_stateless_modified/exp",
+ help="""The experiment dir.
+        It specifies the directory where all training-related
+        files, e.g., checkpoints, logs, etc., are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--lr-factor",
+ type=float,
+ default=5.0,
+ help="The lr_factor for Noam optimizer",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--modified-transducer-prob",
+ type=float,
+ default=0.25,
+ help="""The probability to use modified transducer loss.
+ In modified transduer, it limits the maximum number of symbols
+ per frame to 1. See also the option --max-sym-per-frame in
+ transducer_stateless/decode.py
+ """,
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval:  Print training loss if `batch_idx % log_interval` is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - attention_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer
+                              decoder.
+
+        - warm_step: The warm_step for the Noam optimizer.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 800,
+ # parameters for conformer
+ "feature_dim": 80,
+ "encoder_out_dim": 512,
+ "subsampling_factor": 4,
+ "attention_dim": 512,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "num_encoder_layers": 12,
+ "vgg_frontend": False,
+ # parameters for Noam
+ "warm_step": 80000, # For the 100h subset, use 8k
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Conformer and Transformer
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ output_dim=params.encoder_out_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.attention_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ vgg_frontend=params.vgg_frontend,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ embedding_dim=params.encoder_out_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ input_dim=params.encoder_out_dim,
+ output_dim=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> None:
+ """Load checkpoint from file.
+
+ If params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+ Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+ it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The learning rate scheduler we are using.
+ Returns:
+ Return None.
+ """
+ if params.start_epoch <= 0:
+ return
+
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute CTC loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ texts = batch["supervisions"]["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ modified_transducer_prob=params.modified_transducer_prob,
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (
+ (feature_lens // params.subsampling_factor).sum().item()
+ )
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+
+ optimizer.zero_grad()
+ loss.backward()
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
+ optimizer.step()
+
+ if batch_idx % params.log_interval == 0:
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+
+ if tb_writer is not None:
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(42)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ oov="",
+ )
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank])
+ model.device = device
+
+ optimizer = Noam(
+ model.parameters(),
+ model_size=params.attention_dim,
+ factor=params.lr_factor,
+ warm_step=params.warm_step,
+ )
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ aishell = AishellAsrDataModule(args)
+ train_cuts = aishell.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 12 seconds
+ #
+ # Caution: There is a reason to select 12.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ return 1.0 <= c.duration <= 12.0
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ train_dl = aishell.train_dataloaders(train_cuts)
+ valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
+
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ for epoch in range(params.start_epoch, params.num_epochs):
+ train_dl.sampler.set_epoch(epoch)
+
+ cur_lr = optimizer._rate
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ if rank == 0:
+ logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ )
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+ model: nn.Module,
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ optimizer.zero_grad()
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
+ optimizer.step()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/transformer.py b/egs/aishell/ASR/transducer_stateless_modified/transformer.py
new file mode 120000
index 000000000..214afed39
--- /dev/null
+++ b/egs/aishell/ASR/transducer_stateless_modified/transformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/transformer.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/README.md b/egs/aishell2/ASR/README.md
new file mode 100644
index 000000000..ba38a1ec7
--- /dev/null
+++ b/egs/aishell2/ASR/README.md
@@ -0,0 +1,19 @@
+
+# Introduction
+
+This recipe includes several ASR models trained on Aishell2.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# Transducers
+
+There are various folders in this directory whose names contain `transducer`.
+The following table lists the differences among them.
+
+| | Encoder | Decoder | Comment |
+|---------------------------------------|---------------------|--------------------|-----------------------------|
+| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as `pruned_transducer_stateless5` in the librispeech recipe |
+
+The decoder in `transducer_stateless` is modified from the paper
+[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer.
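+
+For illustration, a minimal sketch of such a stateless decoder (an embedding
+layer followed by a causal Conv1d over the last `context_size` tokens) is
+shown below. The class name, argument names and the final ReLU are only
+illustrative; see the decoder code in the recipe for the exact implementation.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class StatelessDecoder(nn.Module):
+    """Embedding + Conv1d prediction network (illustrative sketch)."""
+
+    def __init__(self, vocab_size: int, embedding_dim: int,
+                 blank_id: int, context_size: int):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embedding_dim,
+                                      padding_idx=blank_id)
+        self.context_size = context_size
+        if context_size > 1:
+            # Mix the embeddings of the last `context_size` tokens.
+            self.conv = nn.Conv1d(embedding_dim, embedding_dim,
+                                  kernel_size=context_size,
+                                  groups=embedding_dim, bias=False)
+
+    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
+        # y: (N, U) token IDs -> (N, U, embedding_dim) when need_pad is True.
+        out = self.embedding(y)
+        if self.context_size > 1:
+            out = out.permute(0, 2, 1)  # (N, C, U)
+            if need_pad:
+                # Left-pad so each output frame sees only previous tokens.
+                out = F.pad(out, pad=(self.context_size - 1, 0))
+            out = self.conv(out).permute(0, 2, 1)
+        return F.relu(out)
+```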
diff --git a/egs/aishell2/ASR/RESULTS.md b/egs/aishell2/ASR/RESULTS.md
new file mode 100644
index 000000000..7114bd5f5
--- /dev/null
+++ b/egs/aishell2/ASR/RESULTS.md
@@ -0,0 +1,89 @@
+## Results
+
+### Aishell2 char-based training results (Pruned Transducer 5)
+
+#### 2022-07-11
+
+Using the code from this PR: https://github.com/k2-fsa/icefall/pull/465.
+
+When training with a context size of 1, the WERs are
+
+| | dev-ios | test-ios | comment |
+|------------------------------------|-------|----------|----------------------------------|
+| greedy search | 5.57 | 5.89 | --epoch 25, --avg 5, --max-duration 600 |
+| modified beam search (beam size 4) | 5.32 | 5.56 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search (set as default) | 5.5 | 5.78 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest | 5.46 | 5.74 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search oracle | 1.92 | 2.2 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest LG | 5.59 | 5.93 | --epoch 25, --avg 5, --max-duration 600 |
+
+The training command for reproducing is given below:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+ --world-size 4 \
+ --lang-dir data/lang_char \
+ --num-epochs 40 \
+ --start-epoch 1 \
+ --exp-dir /result \
+ --max-duration 300 \
+ --use-fp16 0 \
+ --num-encoder-layers 24 \
+ --dim-feedforward 1536 \
+ --nhead 8 \
+ --encoder-dim 384 \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --context-size 1
+```
+
+The decoding command is:
+```bash
+for method in greedy_search modified_beam_search \
+ fast_beam_search fast_beam_search_nbest \
+ fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do
+ ./pruned_transducer_stateless5/decode.py \
+ --epoch 25 \
+ --avg 5 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --max-duration 600 \
+ --decoding-method $method \
+ --max-sym-per-frame 1 \
+ --num-encoder-layers 24 \
+ --dim-feedforward 1536 \
+ --nhead 8 \
+ --encoder-dim 384 \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --context-size 1 \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5 \
+ --context-size 1 \
+ --use-averaged-model True
+done
+```
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/RXyX4QjQQVKjBS2eQ2Qajg/#scalars
+
+A pre-trained model and decoding logs can be found at
+
+When training with a context size of 2, the WERs are
+
+| | dev-ios | test-ios | comment |
+|------------------------------------|-------|----------|----------------------------------|
+| greedy search | 5.47 | 5.81 | --epoch 25, --avg 5, --max-duration 600 |
+| modified beam search (beam size 4) | 5.38 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search (set as default) | 5.36 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest | 5.37 | 5.6 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest oracle      | 2.04  | 2.2      | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest LG | 5.59 | 5.82 | --epoch 25, --avg 5, --max-duration 600 |
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars
+
+A pre-trained model and decoding logs can be found at
diff --git a/egs/aishell2/ASR/local/__init__.py b/egs/aishell2/ASR/local/__init__.py
new file mode 100755
index 000000000..e69de29bb
diff --git a/egs/aishell2/ASR/local/compile_lg.py b/egs/aishell2/ASR/local/compile_lg.py
new file mode 120000
index 000000000..462d6d3fb
--- /dev/null
+++ b/egs/aishell2/ASR/local/compile_lg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
new file mode 100755
index 000000000..7bc969a1a
--- /dev/null
+++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the aishell2 dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
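+
+Usage (run from egs/aishell2/ASR; stage 3 of ./prepare.sh invokes this script):
+
+    ./local/compute_fbank_aishell2.py --num-mel-bins 80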
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_aishell2(num_mel_bins: int = 80):
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+ num_jobs = min(15, os.cpu_count())
+
+ dataset_parts = (
+ "train",
+ "dev",
+ "test",
+ )
+ prefix = "aishell2"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
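+            # For the training set, add speed-perturbed copies at 0.9x and 1.1x
+            # on top of the original cuts, tripling the amount of training data.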
+ if "train" in partition:
+ cut_set = (
+ cut_set
+ + cut_set.perturb_speed(0.9)
+ + cut_set.perturb_speed(1.1)
+ )
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--num-mel-bins",
+ type=int,
+ default=80,
+ help="""The number of mel bins for Fbank""",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ compute_fbank_aishell2(num_mel_bins=args.num_mel_bins)
diff --git a/egs/aishell2/ASR/local/compute_fbank_musan.py b/egs/aishell2/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/aishell2/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/local/display_manifest_statistics.py b/egs/aishell2/ASR/local/display_manifest_statistics.py
new file mode 100755
index 000000000..14844cbf3
--- /dev/null
+++ b/egs/aishell2/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed values to choose the minimum/maximum duration
+for removing short and long utterances during training.
+
+See the function `remove_short_and_long_utt()` in
+pruned_transducer_stateless5/train.py for usage.
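+
+To run it for this recipe (after ./prepare.sh has generated the fbank
+manifests):
+
+    ./local/display_manifest_statistics.py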
+"""
+
+
+from lhotse import load_manifest_lazy
+
+
+def main():
+ paths = [
+ "./data/fbank/aishell2_cuts_train.jsonl.gz",
+ "./data/fbank/aishell2_cuts_dev.jsonl.gz",
+ "./data/fbank/aishell2_cuts_test.jsonl.gz",
+ ]
+
+ for path in paths:
+ print(f"Starting display the statistics for {path}")
+ cuts = load_manifest_lazy(path)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+Starting to display the statistics for ./data/fbank/aishell2_cuts_train.jsonl.gz
+Cuts count: 3026106
+Total duration (hours): 3021.2
+Speech duration (hours): 3021.2 (100.0%)
+***
+Duration statistics (seconds):
+mean 3.6
+std 1.5
+min 0.3
+25% 2.4
+50% 3.3
+75% 4.4
+99% 8.2
+99.5% 8.9
+99.9% 10.6
+max 21.5
+Starting to display the statistics for ./data/fbank/aishell2_cuts_dev.jsonl.gz
+Cuts count: 2500
+Total duration (hours): 2.0
+Speech duration (hours): 2.0 (100.0%)
+***
+Duration statistics (seconds):
+mean 2.9
+std 1.0
+min 1.1
+25% 2.2
+50% 2.7
+75% 3.4
+99% 6.3
+99.5% 6.7
+99.9% 7.8
+max 9.4
+Starting to display the statistics for ./data/fbank/aishell2_cuts_test.jsonl.gz
+Cuts count: 5000
+Total duration (hours): 4.0
+Speech duration (hours): 4.0 (100.0%)
+***
+Duration statistics (seconds):
+mean 2.9
+std 1.0
+min 1.1
+25% 2.2
+50% 2.7
+75% 3.3
+99% 6.2
+99.5% 6.6
+99.9% 7.7
+max 8.5
+"""
diff --git a/egs/aishell2/ASR/local/prepare_char.py b/egs/aishell2/ASR/local/prepare_char.py
new file mode 120000
index 000000000..8779181e5
--- /dev/null
+++ b/egs/aishell2/ASR/local/prepare_char.py
@@ -0,0 +1 @@
+../../../aidatatang_200zh/ASR/local/prepare_char.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/local/prepare_lang.py b/egs/aishell2/ASR/local/prepare_lang.py
new file mode 120000
index 000000000..5d88dc1c8
--- /dev/null
+++ b/egs/aishell2/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../wenetspeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/local/prepare_words.py b/egs/aishell2/ASR/local/prepare_words.py
new file mode 120000
index 000000000..e58fabb8f
--- /dev/null
+++ b/egs/aishell2/ASR/local/prepare_words.py
@@ -0,0 +1 @@
+../../../wenetspeech/ASR/local/prepare_words.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/local/text2segments.py b/egs/aishell2/ASR/local/text2segments.py
new file mode 120000
index 000000000..7d68a39c3
--- /dev/null
+++ b/egs/aishell2/ASR/local/text2segments.py
@@ -0,0 +1 @@
+../../../wenetspeech/ASR/local/text2segments.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/local/text2token.py b/egs/aishell2/ASR/local/text2token.py
new file mode 120000
index 000000000..81e459d69
--- /dev/null
+++ b/egs/aishell2/ASR/local/text2token.py
@@ -0,0 +1 @@
+../../../aidatatang_200zh/ASR/local/text2token.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh
new file mode 100755
index 000000000..06810bfdd
--- /dev/null
+++ b/egs/aishell2/ASR/prepare.sh
@@ -0,0 +1,181 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+nj=30
+stage=0
+stop_stage=5
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, you need to apply for AISHELL-2
+# through its official website:
+# https://www.aishelltech.com/aishell_2
+#
+# - $dl_dir/aishell2
+#
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
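+
+# You can override the variables above (nj, stage, stop_stage, dl_dir) from
+# the command line, e.g. to run only stages 3 to 5:
+#
+#   ./prepare.sh --stage 3 --stop-stage 5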
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/aishell2,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/aishell2 $dl_dir/aishell2
+ #
+ # The directory structure is
+ # aishell2/
+ # |-- AISHELL-2
+ # | |-- iOS
+ # |-- data
+ # |-- wav
+ # |-- trans.txt
+ # |-- dev
+ # |-- wav
+ # |-- trans.txt
+ # |-- test
+ # |-- wav
+ # |-- trans.txt
+
+
+ # If you have pre-downloaded it to /path/to/musan,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/musan $dl_dir/musan
+ #
+ if [ ! -d $dl_dir/musan ]; then
+ lhotse download musan $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare aishell2 manifest"
+  # We assume that you have downloaded and unzipped the aishell2 corpus
+ # to $dl_dir/aishell2
+ if [ ! -f data/manifests/.aishell2_manifests.done ]; then
+ mkdir -p data/manifests
+ lhotse prepare aishell2 $dl_dir/aishell2 data/manifests -j $nj
+ touch data/manifests/.aishell2_manifests.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to $dl_dir/musan
+ if [ ! -f data/manifests/.musan_manifests.done ]; then
+ log "It may take 6 minutes"
+ mkdir -p data/manifests
+ lhotse prepare musan $dl_dir/musan data/manifests
+ touch data/manifests/.musan_manifests.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Compute fbank for aishell2"
+ if [ ! -f data/fbank/.aishell2.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_aishell2.py
+ touch data/fbank/.aishell2.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank for musan"
+  if [ ! -f data/fbank/.musan.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+ fi
+fi
+
+lang_char_dir=data/lang_char
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Prepare char based lang"
+ mkdir -p $lang_char_dir
+
+ # Prepare text.
+ # Note: in Linux, you can install jq with the following command:
+ # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+ # 2. chmod +x ./jq
+ # 3. cp jq /usr/bin
+ if [ ! -f $lang_char_dir/text ]; then
+ gunzip -c data/manifests/aishell2_supervisions_train.jsonl.gz \
+ | jq '.text' | sed 's/"//g' \
+ | ./local/text2token.py -t "char" > $lang_char_dir/text
+ fi
+
+  # Perform Chinese word segmentation on the text;
+  # this takes about 15 minutes.
+ # If you can't install paddle-tiny with python 3.8, please refer to
+ # https://github.com/fxsjy/jieba/issues/920
+ if [ ! -f $lang_char_dir/text_words_segmentation ]; then
+ python3 ./local/text2segments.py \
+ --input-file $lang_char_dir/text \
+ --output-file $lang_char_dir/text_words_segmentation
+ fi
+
+ cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \
+ | sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt
+
+ if [ ! -f $lang_char_dir/words.txt ]; then
+ python3 ./local/prepare_words.py \
+ --input-file $lang_char_dir/words_no_ids.txt \
+ --output-file $lang_char_dir/words.txt
+ fi
+
+ if [ ! -f $lang_char_dir/L_disambig.pt ]; then
+ python3 ./local/prepare_char.py
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Prepare G"
+  # We assume you have installed kaldilm. If not, please install
+  # it using: pip install kaldilm
+
+ if [ ! -f ${lang_char_dir}/3-gram.unpruned.arpa ]; then
+ ./shared/make_kn_lm.py \
+ -ngram-order 3 \
+ -text $lang_char_dir/text_words_segmentation \
+ -lm $lang_char_dir/3-gram.unpruned.arpa
+ fi
+
+ mkdir -p data/lm
+ if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+ # It is used in building LG
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_char_dir/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=3 \
+ $lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
+ fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Compile LG"
+ ./local/compile_lg.py --lang-dir $lang_char_dir
+fi
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell2/ASR/pruned_transducer_stateless5/__init__.py
new file mode 100755
index 000000000..e69de29bb
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
new file mode 100755
index 000000000..b7a21f579
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -0,0 +1,418 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SingleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class AiShell2AsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. ios, android, mic).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
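+
+    A rough usage sketch (decode.py in this folder follows this pattern):
+
+        parser = argparse.ArgumentParser()
+        AiShell2AsrDataModule.add_arguments(parser)
+        args = parser.parse_args()
+        aishell2 = AiShell2AsrDataModule(args)
+        valid_dl = aishell2.valid_dataloaders(aishell2.valid_cuts())
+        test_dl = aishell2.test_dataloaders(aishell2.test_cuts())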
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(
+ self.args.manifest_dir / "musan_cuts.jsonl.gz"
+ )
+ transforms.append(
+ CutMix(
+ cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+ )
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(
+ f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+ )
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SingleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to gen cuts from aishell2_cuts_train.jsonl.gz")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aishell2_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to gen cuts from aishell2_cuts_dev.jsonl.gz")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aishell2_cuts_dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to gen cuts from aishell2_cuts_test.jsonl.gz")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aishell2_cuts_test.jsonl.gz"
+ )
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/beam_search.py b/egs/aishell2/ASR/pruned_transducer_stateless5/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/conformer.py b/egs/aishell2/ASR/pruned_transducer_stateless5/conformer.py
new file mode 120000
index 000000000..c7c1a4b6e
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
new file mode 100755
index 000000000..f03bd34d3
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
@@ -0,0 +1,791 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless5/decode.py \
+ --epoch 25 \
+ --avg 5 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless5/decode.py \
+ --epoch 25 \
+ --avg 5 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless5/decode.py \
+ --epoch 25 \
+ --avg 5 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./pruned_transducer_stateless5/decode.py \
+ --epoch 25 \
+ --avg 5 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless5/decode.py \
+ --epoch 25 \
+ --avg 5 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless5/decode.py \
+ --epoch 25 \
+ --avg 5 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless5/decode.py \
+ --epoch 25 \
+ --avg 5 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir data/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import AiShell2AsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless5/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_char",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=feature, x_lens=feature_lens
+ )
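+    # encoder_out has shape (N, T', encoder_dim); encoder_out_lens gives the
+    # number of valid output frames per utterance after subsampling.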
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ sentence = "".join([lexicon.word_table[i] for i in hyp])
+ hyps.append(list(sentence))
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ graph_compiler=graph_compiler,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ this_batch.append((ref_text, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir
+ / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AiShell2AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += (
+ f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ )
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.unk_id = lexicon.token_table["<unk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg + 1]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
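+            # Scale the LM scores stored on the LG arcs by --ngram-lm-scale.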
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ decoding_graph = k2.trivial_graph(
+ params.vocab_size - 1, device=device
+ )
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ aishell2 = AiShell2AsrDataModule(args)
+
+ valid_cuts = aishell2.valid_cuts()
+ test_cuts = aishell2.test_cuts()
+
+ # use ios sets for dev and test
+ dev_dl = aishell2.valid_dataloaders(valid_cuts)
+ test_dl = aishell2.test_dataloaders(test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dl = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ graph_compiler=graph_compiler,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decoder.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decoder.py
new file mode 120000
index 000000000..722e1c894
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/aishell2/ASR/pruned_transducer_stateless5/encoder_interface.py
new file mode 120000
index 000000000..f58253127
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
new file mode 100755
index 000000000..bc7bd71cb
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
@@ -0,0 +1,274 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless5/export.py \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+  --lang-dir data/lang_char \
+ --epoch 25 \
+ --avg 5
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless5/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/aishell2/ASR
+ ./pruned_transducer_stateless5/decode.py \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --decoding-method greedy_search \
+ --lang-dir data/lang_char
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for averaging.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=False,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless5/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.unk_id = lexicon.token_table["<unk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg + 1]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit:
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+ # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torch.jit.script")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/joiner.py b/egs/aishell2/ASR/pruned_transducer_stateless5/joiner.py
new file mode 120000
index 000000000..9052f3cbb
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/model.py b/egs/aishell2/ASR/pruned_transducer_stateless5/model.py
new file mode 120000
index 000000000..a99e74334
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/model.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/optim.py b/egs/aishell2/ASR/pruned_transducer_stateless5/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
new file mode 100755
index 000000000..09de1bece
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -0,0 +1,342 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless5/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+ --lang-dir ./data/lang_char \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./pruned_transducer_stateless5/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+ --lang-dir ./data/lang_char \
+ --method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./pruned_transducer_stateless5/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+ --lang-dir ./data/lang_char \
+ --method fast_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
+./pruned_transducer_stateless5/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.lexicon import Lexicon
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Path to lang.
+ """,
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame. Used only when
+ --method is greedy_search.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert sample_rate == expected_sample_rate, (
+ f"expected sample rate: {expected_sample_rate}. "
+ f"Given: {sample_rate}"
+ )
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+
+ params.update(vars(args))
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.unk_id = lexicon.token_table["<unk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info("Creating model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(
+ features, batch_first=True, padding_value=math.log(1e-10)
+ )
+
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=features, x_lens=feature_lengths
+ )
+
+ num_waves = encoder_out.size(0)
+ hyps = []
+ msg = f"Using {params.method}"
+ if params.method == "beam_search":
+ msg += f" with beam size {params.beam_size}"
+ logging.info(msg)
+
+ if params.method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ for i in range(num_waves):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(f"Unsupported method: {params.method}")
+
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = "".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/scaling.py b/egs/aishell2/ASR/pruned_transducer_stateless5/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
new file mode 100755
index 000000000..838a0497f
--- /dev/null
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py
@@ -0,0 +1,1131 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao)
+# Copyright 2022 Nvidia (authors: Yuekai Zhang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+ --world-size 4 \
+ --lang-dir data/lang_char \
+ --num-epochs 40 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --max-duration 300 \
+ --use-fp16 0 \
+ --num-encoder-layers 24 \
+ --dim-feedforward 1536 \
+ --nhead 8 \
+ --encoder-dim 384 \
+ --decoder-dim 512 \
+ --joiner-dim 512
+
+# For mixed precision training:
+
+./pruned_transducer_stateless5/train.py \
+ --lang-dir data/lang_char \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AiShell2AsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[
+ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=int,
+ default=24,
+ help="Number of conformer encoder layers..",
+ )
+
+ parser.add_argument(
+ "--dim-feedforward",
+ type=int,
+ default=1536,
+ help="Feedforward dimension of the conformer encoder layer.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=int,
+ default=8,
+ help="Number of attention heads in the conformer encoder layer.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=int,
+ default=384,
+ help="Attention dimension in the conformer encoder layer.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless5/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--initial-lr",
+ type=float,
+ default=0.003,
+ help="The initial learning rate. This value should not need "
+ "to be changed.",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)"
+ "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=100,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warm_step for Noam optimizer.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000,
+ # parameters for conformer
+ "feature_dim": 80,
+ "subsampling_factor": 4,
+ # parameters for Noam
+ "model_warm_step": 3000, # arg given to model, not for lrate
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Conformer and Transformer
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.encoder_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+ warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute RNN-T loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = (
+ model.device
+ if isinstance(model, DDP)
+ else next(model.parameters()).device
+ )
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ texts = batch["supervisions"]["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ assert type(y) == list
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ warmup=warmup,
+ )
+ # after the main warmup step, we keep pruned_loss_scale small
+ # for the same amount of time (model_warm_step), to avoid
+ # overwhelming the simple_loss and causing it to diverge,
+ # in case it had not fully learned the alignment yet.
+ pruned_loss_scale = (
+ 0.0
+ if warmup < 1.0
+ else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+ )
+ loss = (
+ params.simple_loss_scale * simple_loss
+ + pruned_loss_scale * pruned_loss
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (
+ (feature_lens // params.subsampling_factor).sum().item()
+ )
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ warmup=(params.batch_idx_train / params.model_warm_step),
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(
+ batch, params=params, graph_compiler=graph_compiler
+ )
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}"
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ params.blank_id = lexicon.token_table[""]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank])
+
+ optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2 ** 22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ aishell2 = AiShell2AsrDataModule(args)
+
+ train_cuts = aishell2.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 8 seconds
+ #
+ # Caution: There is a reason to select 8.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ return 1.0 <= c.duration <= 8.0
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = aishell2.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = aishell2.valid_cuts()
+ valid_dl = aishell2.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ warmup=0.0 if params.start_epoch == 1 else 1.0,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = graph_compiler.texts_to_ids(supervisions["text"])
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+ warmup: float,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ warmup=warmup,
+ )
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(
+ batch, params=params, graph_compiler=graph_compiler
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ AiShell2AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell2/ASR/shared b/egs/aishell2/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/aishell2/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/aishell4/ASR/README.md b/egs/aishell4/ASR/README.md
new file mode 100644
index 000000000..3744032f8
--- /dev/null
+++ b/egs/aishell4/ASR/README.md
@@ -0,0 +1,19 @@
+
+# Introduction
+
+This recipe contains various ASR models trained with Aishell4 (including its three subsets: S, M, and L).
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# Transducers
+
+There are various folders in this directory whose names contain `transducer`.
+The following table lists the differences among them.
+
+| | Encoder | Decoder | Comment |
+|---------------------------------------|---------------------|--------------------|-----------------------------|
+| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
+
+The decoder (i.e., the prediction network) in `pruned_transducer_stateless5` is modified
+from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer; see the sketch below.
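+
+For reference, a minimal sketch of such a stateless decoder is given below.
+It is only an illustration with made-up names (`StatelessDecoder` is not the
+class used by the recipe); the actual implementation used here lives in the
+recipe's `decoder.py`.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class StatelessDecoder(nn.Module):
+    """Embedding followed by a Conv1d over a fixed left context."""
+
+    def __init__(self, vocab_size: int, decoder_dim: int, context_size: int = 2):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, decoder_dim)
+        # The Conv1d only sees the previous `context_size` symbols, so the
+        # decoder keeps no recurrent state, hence "stateless".
+        self.conv = nn.Conv1d(decoder_dim, decoder_dim, kernel_size=context_size)
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        # y: (N, U) token ids; returns (N, U, decoder_dim)
+        emb = self.embedding(y).permute(0, 2, 1)  # (N, C, U)
+        emb = F.pad(emb, (self.conv.kernel_size[0] - 1, 0))  # pad on the left
+        return self.conv(emb).permute(0, 2, 1)
+```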
diff --git a/egs/aishell4/ASR/RESULTS.md b/egs/aishell4/ASR/RESULTS.md
new file mode 100644
index 000000000..9bd062f1d
--- /dev/null
+++ b/egs/aishell4/ASR/RESULTS.md
@@ -0,0 +1,117 @@
+## Results
+
+### Aishell4 Char training results (Pruned Transducer Stateless5)
+
+#### 2022-06-13
+
+Using the code from this PR: https://github.com/k2-fsa/icefall/pull/399.
+
+When use-averaged-model=False, the CERs are
+| | test | comment |
+|------------------------------------|------------|------------------------------------------|
+| greedy search | 30.05 | --epoch 30, --avg 25, --max-duration 800 |
+| modified beam search (beam size 4) | 29.16 | --epoch 30, --avg 25, --max-duration 800 |
+| fast beam search (set as default) | 29.20 | --epoch 30, --avg 25, --max-duration 1500|
+
+When use-averaged-model=True, the CERs are
+| | test | comment |
+|------------------------------------|------------|----------------------------------------------------------------------|
+| greedy search | 29.89 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True |
+| modified beam search (beam size 4) | 28.91 | --iter 36000, --avg 8, --max-duration 800 --use-averaged-model=True |
+| fast beam search (set as default) | 29.08 | --iter 36000, --avg 8, --max-duration 1500 --use-averaged-model=True |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --lang-dir data/lang_char \
+ --max-duration 220 \
+ --save-every-n 4000
+
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/tjaVRKERS8C10SzhpBcxSQ/#scalars
+
+When use-averaged-model=False, the decoding command is:
+```
+epoch=30
+avg=25
+
+## greedy search
+./pruned_transducer_stateless5/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 800
+
+## modified beam search
+./pruned_transducer_stateless5/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 800 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+## fast beam search
+./pruned_transducer_stateless5/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 1500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+When use-averaged-model=True, the decoding command is:
+```
+iter=36000
+avg=8
+
+## greedy search
+./pruned_transducer_stateless5/decode.py \
+ --iter $iter \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 800 \
+ --use-averaged-model True
+
+## modified beam search
+./pruned_transducer_stateless5/decode.py \
+ --iter $iter \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 800 \
+ --decoding-method modified_beam_search \
+ --beam-size 4 \
+ --use-averaged-model True
+
+## fast beam search
+./pruned_transducer_stateless5/decode.py \
+ --iter $iter \
+ --avg $avg \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 1500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8 \
+ --use-averaged-model True
+```
+
+A pre-trained model and decoding logs can be found at
diff --git a/egs/aishell4/ASR/local/__init__.py b/egs/aishell4/ASR/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
new file mode 100755
index 000000000..09f885636
--- /dev/null
+++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the aishell4 dataset.
+It looks for manifests in the directory data/manifests/aishell4.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_aishell4(num_mel_bins: int = 80):
+ src_dir = Path("data/manifests/aishell4")
+ output_dir = Path("data/fbank")
+ num_jobs = min(15, os.cpu_count())
+
+ dataset_parts = (
+ "train_S",
+ "train_M",
+ "train_L",
+ "test",
+ )
+ prefix = "aishell4"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ if "train" in partition:
+ cut_set = (
+ cut_set
+ + cut_set.perturb_speed(0.9)
+ + cut_set.perturb_speed(1.1)
+ )
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=ChunkedLilcomHdf5Writer,
+ )
+
+ logging.info("About splitting cuts into smaller chunks")
+ cut_set = cut_set.trim_to_supervisions(
+ keep_overlapping=False,
+ min_duration=None,
+ )
+
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--num-mel-bins",
+ type=int,
+ default=80,
+ help="""The number of mel bins for Fbank""",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ compute_fbank_aishell4(num_mel_bins=args.num_mel_bins)
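+
+# A typical invocation looks like the following (illustrative; it assumes the
+# layout used above, i.e. manifests in data/manifests/aishell4 and features
+# written to data/fbank):
+#
+#   cd egs/aishell4/ASR
+#   ./local/compute_fbank_aishell4.py --num-mel-bins 80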
diff --git a/egs/aishell4/ASR/local/compute_fbank_musan.py b/egs/aishell4/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/aishell4/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/aishell4/ASR/local/display_manifest_statistics.py b/egs/aishell4/ASR/local/display_manifest_statistics.py
new file mode 100644
index 000000000..b79e55eef
--- /dev/null
+++ b/egs/aishell4/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,113 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed values to choose the minimum/maximum duration
+to remove short and long utterances during training.
+See the function `remove_short_and_long_utt()`
+in ../../../librispeech/ASR/transducer/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest
+
+
+def main():
+ paths = [
+ "./data/fbank/cuts_train_S.json.gz",
+ "./data/fbank/cuts_train_M.json.gz",
+ "./data/fbank/cuts_train_L.json.gz",
+ "./data/fbank/cuts_test.json.gz",
+ ]
+
+ for path in paths:
+ print(f"Starting display the statistics for {path}")
+ cuts = load_manifest(path)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+Starting display the statistics for ./data/fbank/cuts_train_S.json.gz
+Cuts count: 91995
+Total duration (hours): 95.8
+Speech duration (hours): 95.8 (100.0%)
+***
+Duration statistics (seconds):
+mean 3.7
+std 7.1
+min 0.1
+25% 0.9
+50% 2.5
+75% 5.4
+99% 15.3
+99.5% 17.5
+99.9% 23.3
+max 1021.7
+Starting display the statistics for ./data/fbank/cuts_train_M.json.gz
+Cuts count: 177195
+Total duration (hours): 179.5
+Speech duration (hours): 179.5 (100.0%)
+***
+Duration statistics (seconds):
+mean 3.6
+std 6.4
+min 0.0
+25% 0.9
+50% 2.4
+75% 5.2
+99% 14.9
+99.5% 17.0
+99.9% 23.5
+max 990.4
+Starting display the statistics for ./data/fbank/cuts_train_L.json.gz
+Cuts count: 37572
+Total duration (hours): 49.1
+Speech duration (hours): 49.1 (100.0%)
+***
+Duration statistics (seconds):
+mean 4.7
+std 4.0
+min 0.2
+25% 1.6
+50% 3.7
+75% 6.7
+99% 17.5
+99.5% 19.8
+99.9% 26.2
+max 87.4
+Starting display the statistics for ./data/fbank/cuts_test.json.gz
+Cuts count: 10574
+Total duration (hours): 12.1
+Speech duration (hours): 12.1 (100.0%)
+***
+Duration statistics (seconds):
+mean 4.1
+std 3.4
+min 0.2
+25% 1.4
+50% 3.2
+75% 5.8
+99% 14.4
+99.5% 14.9
+99.9% 16.5
+max 17.9
+"""
diff --git a/egs/aishell4/ASR/local/prepare_char.py b/egs/aishell4/ASR/local/prepare_char.py
new file mode 100755
index 000000000..d9e47d17a
--- /dev/null
+++ b/egs/aishell4/ASR/local/prepare_char.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+
+This script takes as input `lang_dir`, which should contain::
+
+ - lang_dir/text,
+ - lang_dir/words.txt
+
+and generates the following files in the directory `lang_dir`:
+
+ - lexicon.txt
+ - lexicon_disambig.txt
+ - L.pt
+ - L_disambig.pt
+ - tokens.txt
+"""
+
+import re
+from pathlib import Path
+from typing import Dict, List
+
+import k2
+import torch
+from prepare_lang import (
+ Lexicon,
+ add_disambig_symbols,
+ add_self_loops,
+ write_lexicon,
+ write_mapping,
+)
+
+
+def lexicon_to_fst_no_sil(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format).
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ loop_state = 0 # words enter and leave from here
+ next_state = 1 # the next un-allocated state, will be incremented as we go
+
+ arcs = []
+
+ # The blank symbol <blk> is assigned token id 0 in generate_tokens() below
+ assert token2id["<blk>"] == 0
+ assert word2id["<eps>"] == 0
+
+ eps = 0
+
+ for word, pieces in lexicon:
+ assert len(pieces) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ pieces = [
+ token2id[i] if i in token2id else token2id[""] for i in pieces
+ ]
+
+ for i in range(len(pieces) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, pieces[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last piece of this word
+ i = len(pieces) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, pieces[i], w, 0])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
+ """Check if all the given tokens are in token symbol table.
+
+ Args:
+ token_sym_table:
+ Token symbol table that contains all the valid tokens.
+ tokens:
+ A list of tokens.
+ Returns:
+ Return True if there is any token not in the token_sym_table,
+ otherwise False.
+ """
+ for tok in tokens:
+ if tok not in token_sym_table:
+ return True
+ return False
+
+
+def generate_lexicon(
+ token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
+ """Generate a lexicon from a word list and token_sym_table.
+
+ Args:
+ token_sym_table:
+ Token symbol table that maps tokens to token IDs.
+ words:
+ A list of strings representing words.
+ Returns:
+ Return a lexicon, i.e., a list of (word, tokens) pairs, where tokens
+ is the list of characters making up the word.
+ """
+ lexicon = []
+ for word in words:
+ chars = list(word.strip(" \t"))
+ if contain_oov(token_sym_table, chars):
+ continue
+ lexicon.append((word, chars))
+
+ # The OOV word is <UNK>
+ lexicon.append(("<UNK>", ["<unk>"]))
+ return lexicon
+
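+# For example (illustrative): given words ["中国", "你好"] whose characters all
+# appear in token_sym_table, the returned lexicon is
+# [("中国", ["中", "国"]), ("你好", ["你", "好"]), ("<UNK>", ["<unk>"])];
+# any word containing an out-of-vocabulary character is skipped.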
+
+def generate_tokens(text_file: str) -> Dict[str, int]:
+ """Generate tokens from the given text file.
+
+ Args:
+ text_file:
+ A file that contains text lines to generate tokens.
+ Returns:
+ Return a dict whose keys are tokens and values are token ids ranging
+ from 0 to len(keys) - 1.
+ """
+ tokens: Dict[str, int] = dict()
+ tokens[""] = 0
+ tokens[""] = 1
+ tokens[""] = 2
+ whitespace = re.compile(r"([ \t\r\n]+)")
+ with open(text_file, "r", encoding="utf-8") as f:
+ for line in f:
+ line = re.sub(whitespace, "", line)
+ chars = list(line)
+ for char in chars:
+ if char not in tokens:
+ tokens[char] = len(tokens)
+ return tokens
+
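+# For example (illustrative): if the text file contained only the line
+# "好 好 学 习", the returned mapping would be {"<blk>": 0, "<sos/eos>": 1,
+# "<unk>": 2, "好": 3, "学": 4, "习": 5}; whitespace is removed before the
+# characters are enumerated in order of first appearance.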
+
+def main():
+ lang_dir = Path("data/lang_char")
+ text_file = lang_dir / "text"
+
+ word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+ words = word_sym_table.symbols
+
+ excluded = ["", "!SIL", "", "", "#0", "", ""]
+ for w in excluded:
+ if w in words:
+ words.remove(w)
+
+ token_sym_table = generate_tokens(text_file)
+
+ lexicon = generate_lexicon(token_sym_table, words)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ next_token_id = max(token_sym_table.values()) + 1
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in token_sym_table
+ token_sym_table[disambig] = next_token_id
+ next_token_id += 1
+
+ word_sym_table.add("#0")
+ word_sym_table.add("")
+ word_sym_table.add("")
+
+ write_mapping(lang_dir / "tokens.txt", token_sym_table)
+
+ write_lexicon(lang_dir / "lexicon.txt", lexicon)
+ write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst_no_sil(
+ lexicon,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ )
+
+ L_disambig = lexicon_to_fst_no_sil(
+ lexicon_disambig,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), lang_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell4/ASR/local/prepare_lang.py b/egs/aishell4/ASR/local/prepare_lang.py
new file mode 100755
index 000000000..e5ae89ec4
--- /dev/null
+++ b/egs/aishell4/ASR/local/prepare_lang.py
@@ -0,0 +1,390 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
+consisting of words and tokens (i.e., phones) and does the following:
+
+1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
+
+2. Generate tokens.txt, the token table mapping a token to a unique integer.
+
+3. Generate words.txt, the word table mapping a word to a unique integer.
+
+4. Generate L.pt, in k2 format. It can be loaded by
+
+ d = torch.load("L.pt")
+ lexicon = k2.Fsa.from_dict(d)
+
+5. Generate L_disambig.pt, in k2 format.
+"""
+import argparse
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import k2
+import torch
+
+from icefall.lexicon import read_lexicon, write_lexicon
+
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_tokens(lexicon: Lexicon) -> List[str]:
+ """Get tokens from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique tokens.
+ """
+ ans = set()
+ for _, tokens in lexicon:
+ ans.update(tokens)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def get_words(lexicon: Lexicon) -> List[str]:
+ """Get words from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique words.
+ """
+ ans = set()
+ for word, _ in lexicon:
+ ans.add(word)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
+ """It adds pseudo-token disambiguation symbols #1, #2 and so on
+ at the ends of tokens to ensure that all pronunciations are different,
+ and that none is a prefix of another.
+
+ See also add_lex_disambig.pl from kaldi.
+
+ Args:
+ lexicon:
+ It is returned by :func:`read_lexicon`.
+ Returns:
+ Return a tuple with two elements:
+
+ - The output lexicon with disambiguation symbols
+ - The ID of the max disambiguation symbol that appears
+ in the lexicon
+ """
+
+ # (1) Work out the count of each token-sequence in the
+ # lexicon.
+ count = defaultdict(int)
+ for _, tokens in lexicon:
+ count[" ".join(tokens)] += 1
+
+ # (2) For each left sub-sequence of each token-sequence, note down
+ # that it exists (for identifying prefixes of longer strings).
+ issubseq = defaultdict(int)
+ for _, tokens in lexicon:
+ tokens = tokens.copy()
+ tokens.pop()
+ while tokens:
+ issubseq[" ".join(tokens)] = 1
+ tokens.pop()
+
+ # (3) For each entry in the lexicon:
+ # if the token sequence is unique and is not a
+ # prefix of another word, no disambig symbol.
+ # Else output #1, or #2, #3, ... if the same token-seq
+ # has already been assigned a disambig symbol.
+ ans = []
+
+ # We start with #1 since #0 has its own purpose
+ first_allowed_disambig = 1
+ max_disambig = first_allowed_disambig - 1
+ last_used_disambig_symbol_of = defaultdict(int)
+
+ for word, tokens in lexicon:
+ tokenseq = " ".join(tokens)
+ assert tokenseq != ""
+ if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
+ ans.append((word, tokens))
+ continue
+
+ cur_disambig = last_used_disambig_symbol_of[tokenseq]
+ if cur_disambig == 0:
+ cur_disambig = first_allowed_disambig
+ else:
+ cur_disambig += 1
+
+ if cur_disambig > max_disambig:
+ max_disambig = cur_disambig
+ last_used_disambig_symbol_of[tokenseq] = cur_disambig
+ tokenseq += f" #{cur_disambig}"
+ ans.append((word, tokenseq.split()))
+ return ans, max_disambig
+
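+# For example (cf. the lexicon in local/test_prepare_lang.py): with entries
+#   food f o o d
+#   food2 f o o d
+#   fo f o
+# "food" and "food2" share a token sequence and "f o" is a prefix of it, so
+# the disambiguated lexicon contains "f o o d #1", "f o o d #2" and "f o #1".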
+
+def generate_id_map(symbols: List[str]) -> Dict[str, int]:
+ """Generate ID maps, i.e., map a symbol to a unique ID.
+
+ Args:
+ symbols:
+ A list of unique symbols.
+ Returns:
+ A dict containing the mapping between symbols and IDs.
+ """
+ return {sym: i for i, sym in enumerate(symbols)}
+
+
+def add_self_loops(
+ arcs: List[List[Any]], disambig_token: int, disambig_word: int
+) -> List[List[Any]]:
+ """Adds self-loops to states of an FST to propagate disambiguation symbols
+ through it. They are added on each state with non-epsilon output symbols
+ on at least one arc out of the state.
+
+ See also fstaddselfloops.pl from Kaldi. One difference is that
+ Kaldi uses OpenFst style FSTs and it has multiple final states.
+ This function uses k2 style FSTs and it does not need to add self-loops
+ to the final state.
+
+ The input label of a self-loop is `disambig_token`, while the output
+ label is `disambig_word`.
+
+ Args:
+ arcs:
+ A list-of-list. The sublist contains
+ `[src_state, dest_state, label, aux_label, score]`
+ disambig_token:
+ It is the token ID of the symbol `#0`.
+ disambig_word:
+ It is the word ID of the symbol `#0`.
+
+ Return:
+ Return new `arcs` containing self-loops.
+ """
+ states_needs_self_loops = set()
+ for arc in arcs:
+ src, dst, ilabel, olabel, score = arc
+ if olabel != 0:
+ states_needs_self_loops.add(src)
+
+ ans = []
+ for s in states_needs_self_loops:
+ ans.append([s, s, disambig_token, disambig_word, 0])
+
+ return arcs + ans
+
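+# For illustration: if `arcs` contained [0, 1, 3, 5, 0] (an arc leaving state 0
+# with non-epsilon output label 5), the returned list would additionally
+# contain the self-loop [0, 0, disambig_token, disambig_word, 0] on state 0.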
+
+def lexicon_to_fst(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ sil_token: str = "SIL",
+ sil_prob: float = 0.5,
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format) with optional silence at
+ the beginning and end of each word.
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ sil_token:
+ The silence token.
+ sil_prob:
+ The probability for adding a silence at the beginning and end
+ of the word.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ assert sil_prob > 0.0 and sil_prob < 1.0
+ # CAUTION: we use score, i.e, negative cost.
+ sil_score = math.log(sil_prob)
+ no_sil_score = math.log(1.0 - sil_prob)
+
+ start_state = 0
+ loop_state = 1 # words enter and leave from here
+ sil_state = 2 # words terminate here when followed by silence; this state
+ # has a silence transition to loop_state.
+ next_state = 3 # the next un-allocated state, will be incremented as we go.
+ arcs = []
+
+ assert token2id[""] == 0
+ assert word2id[""] == 0
+
+ eps = 0
+
+ sil_token = token2id[sil_token]
+
+ arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+ arcs.append([start_state, sil_state, eps, eps, sil_score])
+ arcs.append([sil_state, loop_state, sil_token, eps, 0])
+
+ for word, tokens in lexicon:
+ assert len(tokens) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ tokens = [token2id[i] for i in tokens]
+
+ for i in range(len(tokens) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, tokens[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last token of this word
+ # It has two out-going arcs, one to the loop state,
+ # the other one to the sil_state.
+ i = len(tokens) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
+ arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+ )
+ return parser.parse_args()
+
+
+def main():
+ out_dir = Path(get_args().lang_dir)
+ lexicon_filename = out_dir / "lexicon.txt"
+ sil_token = "SIL"
+ sil_prob = 0.5
+
+ lexicon = read_lexicon(lexicon_filename)
+ tokens = get_tokens(lexicon)
+ words = get_words(lexicon)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in tokens
+ tokens.append(f"#{i}")
+
+ assert "" not in tokens
+ tokens = [""] + tokens
+
+ assert "" not in words
+ assert "#0" not in words
+ assert "" not in words
+ assert "" not in words
+
+ words = [""] + words + ["#0", "", ""]
+
+ token2id = generate_id_map(tokens)
+ word2id = generate_id_map(words)
+
+ write_mapping(out_dir / "tokens.txt", token2id)
+ write_mapping(out_dir / "words.txt", word2id)
+ write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst(
+ lexicon,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ )
+
+ L_disambig = lexicon_to_fst(
+ lexicon_disambig,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), out_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
+
+ if False:
+ # Just for debugging, will remove it
+ L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
+ L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
+ L_disambig.labels_sym = L.labels_sym
+ L_disambig.aux_labels_sym = L.aux_labels_sym
+ L.draw(out_dir / "L.png", title="L")
+ L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell4/ASR/local/prepare_words.py b/egs/aishell4/ASR/local/prepare_words.py
new file mode 100755
index 000000000..65aca2983
--- /dev/null
+++ b/egs/aishell4/ASR/local/prepare_words.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input a word list without IDs:
+ - words_no_ids.txt
+and generates words.txt, in which each word is assigned an ID:
+ - words.txt
+"""
+
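+# For example (illustrative): if words_no_ids.txt contains the two lines
+# "中国" and "你好", the generated words.txt starts with the four reserved
+# entries listed in `add_words` below and then contains "中国 4" and "你好 5".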
+
+import argparse
+import logging
+
+from tqdm import tqdm
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Prepare words.txt",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--input-file",
+ default="data/lang_char/words_no_ids.txt",
+ type=str,
+ help="the words file without ids for WenetSpeech",
+ )
+ parser.add_argument(
+ "--output-file",
+ default="data/lang_char/words.txt",
+ type=str,
+ help="the words file with ids for WenetSpeech",
+ )
+
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ input_file = args.input_file
+ output_file = args.output_file
+
+ f = open(input_file, "r", encoding="utf-8")
+ lines = f.readlines()
+ new_lines = []
+ add_words = [" 0", "!SIL 1", " 2", " 3"]
+ new_lines.extend(add_words)
+
+ logging.info("Starting reading the input file")
+ for i in tqdm(range(len(lines))):
+ x = lines[i]
+ idx = 4 + i
+ new_line = str(x.strip("\n")) + " " + str(idx)
+ new_lines.append(new_line)
+
+ logging.info("Starting writing the words.txt")
+ f_out = open(output_file, "w", encoding="utf-8")
+ for line in new_lines:
+ f_out.write(line)
+ f_out.write("\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell4/ASR/local/test_prepare_lang.py b/egs/aishell4/ASR/local/test_prepare_lang.py
new file mode 100755
index 000000000..d4cf62bba
--- /dev/null
+++ b/egs/aishell4/ASR/local/test_prepare_lang.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+
+import os
+import tempfile
+
+import k2
+from prepare_lang import (
+ add_disambig_symbols,
+ generate_id_map,
+ get_tokens,
+ get_words,
+ lexicon_to_fst,
+ read_lexicon,
+ write_lexicon,
+ write_mapping,
+)
+
+
+def generate_lexicon_file() -> str:
+ fd, filename = tempfile.mkstemp()
+ os.close(fd)
+ s = """
+ !SIL SIL
+ <SPOKEN_NOISE> SPN
+ <UNK> SPN
+ f f
+ a a
+ foo f o o
+ bar b a r
+ bark b a r k
+ food f o o d
+ food2 f o o d
+ fo f o
+ """.strip()
+ with open(filename, "w") as f:
+ f.write(s)
+ return filename
+
+
+def test_read_lexicon(filename: str):
+ lexicon = read_lexicon(filename)
+ phones = get_tokens(lexicon)
+ words = get_words(lexicon)
+ print(lexicon)
+ print(phones)
+ print(words)
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+ print(lexicon_disambig)
+ print("max disambig:", f"#{max_disambig}")
+
+ phones = ["", "SIL", "SPN"] + phones
+ for i in range(max_disambig + 1):
+ phones.append(f"#{i}")
+ words = [""] + words
+
+ phone2id = generate_id_map(phones)
+ word2id = generate_id_map(words)
+
+ print(phone2id)
+ print(word2id)
+
+ write_mapping("phones.txt", phone2id)
+ write_mapping("words.txt", word2id)
+
+ write_lexicon("a.txt", lexicon)
+ write_lexicon("a_disambig.txt", lexicon_disambig)
+
+ fsa = lexicon_to_fst(lexicon, token2id=phone2id, word2id=word2id)
+ fsa.labels_sym = k2.SymbolTable.from_file("phones.txt")
+ fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
+ fsa.draw("L.pdf", title="L")
+
+ fsa_disambig = lexicon_to_fst(
+ lexicon_disambig, token2id=phone2id, word2id=word2id
+ )
+ fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
+ fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
+ fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
+
+
+def main():
+ filename = generate_lexicon_file()
+ test_read_lexicon(filename)
+ os.remove(filename)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell4/ASR/local/text2segments.py b/egs/aishell4/ASR/local/text2segments.py
new file mode 100644
index 000000000..3df727c67
--- /dev/null
+++ b/egs/aishell4/ASR/local/text2segments.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input "text", which refers to the transcript file for
+WenetSpeech:
+ - text
+and generates the output file text_word_segmentation which is implemented
+with word segmenting:
+ - text_words_segmentation
+"""
+
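+# For example (illustrative; the actual segmentation depends on jieba's
+# model): the input line "我爱北京天安门" may be segmented into
+# "我 爱 北京 天安门".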
+
+import argparse
+
+import jieba
+from tqdm import tqdm
+
+jieba.enable_paddle()
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Chinese Word Segmentation for text",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--input-file",
+ default="data/lang_char/text",
+ type=str,
+ help="the input text file for WenetSpeech",
+ )
+ parser.add_argument(
+ "--output-file",
+ default="data/lang_char/text_words_segmentation",
+ type=str,
+ help="the text implemented with words segmenting for WenetSpeech",
+ )
+
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ input_file = args.input_file
+ output_file = args.output_file
+
+ f = open(input_file, "r", encoding="utf-8")
+ lines = f.readlines()
+ new_lines = []
+ for i in tqdm(range(len(lines))):
+ x = lines[i].rstrip()
+ seg_list = jieba.cut(x, use_paddle=True)
+ new_line = " ".join(seg_list)
+ new_lines.append(new_line)
+
+ f_new = open(output_file, "w", encoding="utf-8")
+ for line in new_lines:
+ f_new.write(line)
+ f_new.write("\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell4/ASR/local/text2token.py b/egs/aishell4/ASR/local/text2token.py
new file mode 100755
index 000000000..71be2a613
--- /dev/null
+++ b/egs/aishell4/ASR/local/text2token.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env python3
+# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe)
+# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import codecs
+import re
+import sys
+from typing import List
+
+from pypinyin import lazy_pinyin, pinyin
+
+is_python2 = sys.version_info[0] == 2
+
+
+def exist_or_not(i, match_pos):
+ start_pos = None
+ end_pos = None
+ for pos in match_pos:
+ if pos[0] <= i < pos[1]:
+ start_pos = pos[0]
+ end_pos = pos[1]
+ break
+
+ return start_pos, end_pos
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="convert raw text to tokenized text",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--nchar",
+ "-n",
+ default=1,
+ type=int,
+ help="number of characters to split, i.e., \
+ aabb -> a a b b with -n 1 and aa bb with -n 2",
+ )
+ parser.add_argument(
+ "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
+ )
+ parser.add_argument(
+ "--space", default="", type=str, help="space symbol"
+ )
+ parser.add_argument(
+ "--non-lang-syms",
+ "-l",
+ default=None,
+ type=str,
+ help="list of non-linguistic symobles, e.g., etc.",
+ )
+ parser.add_argument(
+ "text", type=str, default=False, nargs="?", help="input text"
+ )
+ parser.add_argument(
+ "--trans_type",
+ "-t",
+ type=str,
+ default="char",
+ choices=["char", "pinyin", "lazy_pinyin"],
+ help="""Transcript type. char/pinyin/lazy_pinyin""",
+ )
+ return parser
+
+
+def token2id(
+ texts, token_table, token_type: str = "lazy_pinyin", oov: str = "<unk>"
+) -> List[List[int]]:
+ """Convert token to id.
+ Args:
+ texts:
+ The input texts; here they refer to Chinese text.
+ token_table:
+ The token table is built based on "data/lang_xxx/token.txt"
+ token_type:
+ The type of token, such as "pinyin" and "lazy_pinyin".
+ oov:
+ Out-of-vocabulary token. When a word (token) in the transcript
+ does not exist in the token list, it is replaced with `oov`.
+
+ Returns:
+ The list of ids for the input texts.
+ """
+ if texts is None:
+ raise ValueError("texts can't be None!")
+ else:
+ oov_id = token_table[oov]
+ ids: List[List[int]] = []
+ for text in texts:
+ chars_list = list(str(text))
+ if token_type == "lazy_pinyin":
+ text = lazy_pinyin(chars_list)
+ sub_ids = [
+ token_table[txt] if txt in token_table else oov_id
+ for txt in text
+ ]
+ ids.append(sub_ids)
+ else: # token_type = "pinyin"
+ text = pinyin(chars_list)
+ sub_ids = [
+ token_table[txt[0]] if txt[0] in token_table else oov_id
+ for txt in text
+ ]
+ ids.append(sub_ids)
+ return ids
+
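+# Illustrative example (with a hypothetical token_table): for
+# token_type="lazy_pinyin", the text "你好" is first converted to
+# ["ni", "hao"]; each syllable is then looked up in token_table, and any
+# syllable missing from the table falls back to the ID of the `oov` token.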
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ rs = []
+ if args.non_lang_syms is not None:
+ with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
+ nls = [x.rstrip() for x in f.readlines()]
+ rs = [re.compile(re.escape(x)) for x in nls]
+
+ if args.text:
+ f = codecs.open(args.text, encoding="utf-8")
+ else:
+ f = codecs.getreader("utf-8")(
+ sys.stdin if is_python2 else sys.stdin.buffer
+ )
+
+ sys.stdout = codecs.getwriter("utf-8")(
+ sys.stdout if is_python2 else sys.stdout.buffer
+ )
+ line = f.readline()
+ n = args.nchar
+ while line:
+ x = line.split()
+ print(" ".join(x[: args.skip_ncols]), end=" ")
+ a = " ".join(x[args.skip_ncols :]) # noqa E203
+
+ # get all matched positions
+ match_pos = []
+ for r in rs:
+ i = 0
+ while i >= 0:
+ m = r.search(a, i)
+ if m:
+ match_pos.append([m.start(), m.end()])
+ i = m.end()
+ else:
+ break
+ if len(match_pos) > 0:
+ chars = []
+ i = 0
+ while i < len(a):
+ start_pos, end_pos = exist_or_not(i, match_pos)
+ if start_pos is not None:
+ chars.append(a[start_pos:end_pos])
+ i = end_pos
+ else:
+ chars.append(a[i])
+ i += 1
+ a = chars
+
+ if args.trans_type == "pinyin":
+ a = pinyin(list(str(a)))
+ a = [one[0] for one in a]
+
+ if args.trans_type == "lazy_pinyin":
+ a = lazy_pinyin(list(str(a)))
+
+ a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203
+
+ a_flat = []
+ for z in a:
+ a_flat.append("".join(z))
+
+ a_chars = "".join(a_flat)
+ print(a_chars)
+ line = f.readline()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell4/ASR/local/text_normalize.py b/egs/aishell4/ASR/local/text_normalize.py
new file mode 100755
index 000000000..5650be502
--- /dev/null
+++ b/egs/aishell4/ASR/local/text_normalize.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2022 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input "text_full", which includes three transcript files
+(train_S, train_M and train_L) for AISHELL4:
+ - text_full
+and generates the output file text_normalize which is implemented
+to normalize text:
+ - text
+"""
+
+
+import argparse
+
+from tqdm import tqdm
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Normalizing for text",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--input",
+ default="data/lang_char/text_full",
+ type=str,
+ help="the input text files for AISHELL4",
+ )
+ parser.add_argument(
+ "--output",
+ default="data/lang_char/text",
+ type=str,
+ help="the text implemented with normalizer for AISHELL4",
+ )
+
+ return parser
+
+
+def text_normalize(str_line: str):
+ line = str_line.strip().rstrip("\n")
+ line = line.replace(" ", "")
+ line = line.replace("", "")
+ line = line.replace("<%>", "")
+ line = line.replace("<->", "")
+ line = line.replace("<$>", "")
+ line = line.replace("<#>", "")
+ line = line.replace("<_>", "")
+ line = line.replace("", "")
+ line = line.replace("`", "")
+ line = line.replace("&", "")
+ line = line.replace(",", "")
+ line = line.replace("A", "")
+ line = line.replace("a", "A")
+ line = line.replace("b", "B")
+ line = line.replace("c", "C")
+ line = line.replace("k", "K")
+ line = line.replace("t", "T")
+ line = line.replace(",", "")
+ line = line.replace("丶", "")
+ line = line.replace("。", "")
+ line = line.replace("、", "")
+ line = line.replace("?", "")
+ line = line.replace("·", "")
+ line = line.replace("*", "")
+ line = line.replace("!", "")
+ line = line.replace("$", "")
+ line = line.replace("+", "")
+ line = line.replace("-", "")
+ line = line.replace("\\", "")
+ line = line.replace("?", "")
+ line = line.replace("¥", "")
+ line = line.replace("%", "")
+ line = line.replace(".", "")
+ line = line.replace("<", "")
+ line = line.replace("&", "")
+ line = line.upper()
+
+ return line
+
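+# For example (illustrative): the line "今天 天气<%>不错, 啊a" becomes
+# "今天天气不错啊A" once spaces, noise tags and punctuation are removed and
+# Latin letters are upper-cased.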
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ input_file = args.input
+ output_file = args.output
+
+ f = open(input_file, "r", encoding="utf-8")
+ lines = f.readlines()
+ new_lines = []
+ for i in tqdm(range(len(lines))):
+ new_line = text_normalize(lines[i])
+ new_lines.append(new_line)
+
+ f_new = open(output_file, "w", encoding="utf-8")
+ for line in new_lines:
+ f_new.write(line)
+ f_new.write("\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh
new file mode 100755
index 000000000..c351e3964
--- /dev/null
+++ b/egs/aishell4/ASR/prepare.sh
@@ -0,0 +1,160 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/aishell4
+# You can find four directories: train_S, train_M, train_L and test.
+# You can download it from https://openslr.org/111/
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/aishell4,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/aishell4 $dl_dir/aishell4
+ #
+ if [ ! -d $dl_dir/aishell4/train_L ]; then
+ lhotse download aishell4 $dl_dir/aishell4
+ fi
+
+ # If you have pre-downloaded it to /path/to/musan,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/musan $dl_dir/musan
+ #
+ if [ ! -d $dl_dir/musan ]; then
+ lhotse download musan $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare aishell4 manifest"
+ # We assume that you have downloaded the aishell4 corpus
+ # to $dl_dir/aishell4
+ if [ ! -f data/manifests/aishell4/.manifests.done ]; then
+ mkdir -p data/manifests/aishell4
+ lhotse prepare aishell4 $dl_dir/aishell4 data/manifests/aishell4
+ touch data/manifests/aishell4/.manifests.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Process aishell4"
+ if [ ! -f data/fbank/aishell4/.fbank.done ]; then
+ mkdir -p data/fbank/aishell4
+ lhotse prepare aishell4 $dl_dir/aishell4 data/manifests/aishell4
+ touch data/fbank/aishell4/.fbank.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to data/musan
+ if [ ! -f data/manifests/.musan_manifests.done ]; then
+ log "It may take 6 minutes"
+ mkdir -p data/manifests
+ lhotse prepare musan $dl_dir/musan data/manifests
+ touch data/manifests/.musan_manifests.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank for musan"
+ if [ ! -f data/fbank/.musan.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_musan.py
+ touch data/fbank/.musan.done
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Compute fbank for aishell4"
+ if [ ! -f data/fbank/.aishell4.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_aishell4.py
+ touch data/fbank/.aishell4.done
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Prepare char based lang"
+ lang_char_dir=data/lang_char
+ mkdir -p $lang_char_dir
+
+ # Prepare text.
+ # Note: in Linux, you can install jq with the following command:
+ # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+ gunzip -c data/manifests/aishell4/aishell4_supervisions_train_S.jsonl.gz \
+ | jq ".text" | sed 's/"//g' \
+ | ./local/text2token.py -t "char" > $lang_char_dir/text_S
+
+ gunzip -c data/manifests/aishell4/aishell4_supervisions_train_M.jsonl.gz \
+ | jq ".text" | sed 's/"//g' \
+ | ./local/text2token.py -t "char" > $lang_char_dir/text_M
+
+ gunzip -c data/manifests/aishell4/aishell4_supervisions_train_L.jsonl.gz \
+ | jq ".text" | sed 's/"//g' \
+ | ./local/text2token.py -t "char" > $lang_char_dir/text_L
+
+ for r in text_S text_M text_L ; do
+ cat $lang_char_dir/$r >> $lang_char_dir/text_full
+ done
+
+ # Normalize the text
+ python ./local/text_normalize.py \
+ --input $lang_char_dir/text_full \
+ --output $lang_char_dir/text
+
+ # Prepare word segmentation
+ python ./local/text2segments.py \
+ --input $lang_char_dir/text \
+ --output $lang_char_dir/text_words_segmentation
+
+ cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \
+ | sort -u | sed "/^$/d" \
+ | uniq > $lang_char_dir/words_no_ids.txt
+
+ # Prepare words.txt
+ if [ ! -f $lang_char_dir/words.txt ]; then
+ ./local/prepare_words.py \
+ --input-file $lang_char_dir/words_no_ids.txt \
+ --output-file $lang_char_dir/words.txt
+ fi
+
+ if [ ! -f $lang_char_dir/L_disambig.pt ]; then
+ ./local/prepare_char.py
+ fi
+fi
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/__init__.py b/egs/aishell4/ASR/pruned_transducer_stateless5/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
new file mode 100644
index 000000000..7aa53ddda
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -0,0 +1,448 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SingleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class Aishell4AsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=300,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(
+ self.args.manifest_dir / "musan_cuts.jsonl.gz"
+ )
+
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ transforms.append(
+ CutMix(
+ cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+ )
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(
+ f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+ )
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ buffer_size=30000,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SingleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_dl.sampler.load_state_dict(sampler_state_dict)
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ rank=0,
+ world_size=1,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ rank=0,
+ world_size=1,
+ shuffle=False,
+ )
+ logging.info("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_S_cuts(self) -> CutSet:
+ logging.info("About to get S train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aishell4_cuts_train_S.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_M_cuts(self) -> CutSet:
+ logging.info("About to get M train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aishell4_cuts_train_M.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_L_cuts(self) -> CutSet:
+ logging.info("About to get L train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aishell4_cuts_train_L.jsonl.gz"
+ )
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ # AISHELL-4 doesn't have a dev set, so we use the test set as dev.
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_cuts(self) -> List[CutSet]:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "aishell4_cuts_test.jsonl.gz"
+ )
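+# A minimal usage sketch (assuming a script that builds an argparse parser,
+# as decode.py in this directory does):
+#
+#   parser = argparse.ArgumentParser()
+#   Aishell4AsrDataModule.add_arguments(parser)
+#   args = parser.parse_args()
+#
+#   aishell4 = Aishell4AsrDataModule(args)
+#   train_dl = aishell4.train_dataloaders(aishell4.train_L_cuts())
+#   valid_dl = aishell4.valid_dataloaders(aishell4.valid_cuts())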
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/beam_search.py b/egs/aishell4/ASR/pruned_transducer_stateless5/beam_search.py
new file mode 120000
index 000000000..ed78bd4bb
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/beam_search.py
@@ -0,0 +1 @@
+../../../../egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/conformer.py b/egs/aishell4/ASR/pruned_transducer_stateless5/conformer.py
new file mode 120000
index 000000000..c7c1a4b6e
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
\ No newline at end of file
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
new file mode 100755
index 000000000..d329410e1
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py
@@ -0,0 +1,636 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+When --use-averaged-model=True, usage:
+(1) greedy search
+./pruned_transducer_stateless5/decode.py \
+ --iter 36000 \
+ --avg 8 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --max-duration 800 \
+ --decoding-method greedy_search \
+ --use-averaged-model True
+
+(2) modified beam search
+./pruned_transducer_stateless5/decode.py \
+ --iter 36000 \
+ --avg 8 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --max-duration 800 \
+ --decoding-method modified_beam_search \
+ --beam-size 4 \
+ --use-averaged-model True
+
+(3) fast beam search
+./pruned_transducer_stateless5/decode.py \
+ --iter 36000 \
+ --avg 8 \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --max-duration 800 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8 \
+ --use-averaged-model True
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import Aishell4AsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from local.text_normalize import text_normalize
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=False,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless5/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding_method is fast_beam_search.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=feature, x_lens=feature_lens
+ )
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
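+# For instance (hypothetical hypotheses): with --decoding-method greedy_search
+# and a batch of two utterances, the dict returned above could look like
+# {"greedy_search": [["今", "天"], ["你", "好"]]}.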
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      lexicon:
+        It contains the token table for mapping token IDs back to tokens.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if a beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains two elements:
+      the first is the reference transcript, and the second is the
+      predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ texts = [list(str(text).replace(" ", "")) for text in texts]
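+        # Scoring is character based for this recipe: spaces are removed
+        # and each reference transcript is split into a list of characters.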
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ this_batch.append((ref_text, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
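+    """Save the recognition results and WER statistics to files.
+
+    For each decoding setting (key), this writes the transcripts, the
+    detailed error statistics, and a WER summary under `params.res_dir`.
+    """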
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir
+ / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WERs of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ Aishell4AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += (
+ f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ )
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
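+        # The resulting suffix looks like, e.g.,
+        # "epoch-30-avg-10-context-2-max-sym-per-frame-1" (illustrative
+        # values; they depend on the command-line arguments).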
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
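+    # Token IDs start from 0 (the blank token), so the vocabulary size is
+    # the largest token ID plus one.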
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg + 1]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ ),
+ strict=False,
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ ),
+ strict=False,
+ )
+
+ model.to(device)
+ model.eval()
+
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
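+        # The trivial graph accepts any token sequence, i.e. it places no
+        # lexical constraint on fast_beam_search.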
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ def text_normalize_for_cut(c: Cut):
+ # Text normalize for each sample
+ text = c.supervisions[0].text
+ text = text.strip("\n").strip("\t")
+ c.supervisions[0].text = text_normalize(text)
+ return c
+
+ aishell4 = Aishell4AsrDataModule(args)
+ test_cuts = aishell4.test_cuts()
+ test_cuts = test_cuts.map(text_normalize_for_cut)
+ test_dl = aishell4.test_dataloaders(test_cuts)
+
+    test_sets = ["test"]
+    test_dls = [test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ )
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decoder.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decoder.py
new file mode 120000
index 000000000..8a5e07bd5
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decoder.py
@@ -0,0 +1 @@
+../../../../egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
\ No newline at end of file
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/aishell4/ASR/pruned_transducer_stateless5/encoder_interface.py
new file mode 120000
index 000000000..2fc10439b
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/encoder_interface.py
@@ -0,0 +1 @@
+../../../../egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
new file mode 100755
index 000000000..993341131
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless5/export.py \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir data/lang_char \
+ --epoch 20 \
+ --avg 10
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless5/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/aishell4/ASR
+ ./pruned_transducer_stateless5/decode.py \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --decoding-method greedy_search \
+ --lang-dir data/lang_char
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for averaging.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=False,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless5/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+        help="""The lang dir.
+        It contains language related input files such as
+        "lexicon.txt".
+ """,
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = lexicon.token_table[""]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(
+ params.exp_dir, iteration=-params.iter
+ )[: params.avg + 1]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ ),
+ strict=False,
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ ),
+ strict=False,
+ )
+
+    model.to("cpu")
+    model.eval()
+
+ if params.jit:
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+ # torch scriptabe.
+        # torch scriptable.
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torch.jit.script")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/joiner.py b/egs/aishell4/ASR/pruned_transducer_stateless5/joiner.py
new file mode 120000
index 000000000..f31b5fd9b
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/joiner.py
@@ -0,0 +1 @@
+../../../../egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py
\ No newline at end of file
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/model.py b/egs/aishell4/ASR/pruned_transducer_stateless5/model.py
new file mode 120000
index 000000000..be059ba7c
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/model.py
@@ -0,0 +1 @@
+../../../../egs/librispeech/ASR/pruned_transducer_stateless2/model.py
\ No newline at end of file
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/optim.py b/egs/aishell4/ASR/pruned_transducer_stateless5/optim.py
new file mode 120000
index 000000000..661206562
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/optim.py
@@ -0,0 +1 @@
+../../../../egs/librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
new file mode 100755
index 000000000..1fa893637
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
@@ -0,0 +1,358 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+When use-averaged-model=True, usage:
+
+(1) greedy search
+./pruned_transducer_stateless5/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+ --lang-dir data/lang_char \
+ --decoding-method greedy_search \
+ --use-averaged-model True \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless5/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+ --lang-dir data/lang_char \
+ --use-averaged-model True \
+ --decoding-method beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) modified beam search (not recommended)
+./pruned_transducer_stateless5/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+ --lang-dir data/lang_char \
+ --use-averaged-model True \
+ --decoding-method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(4) fast beam search
+./pruned_transducer_stateless5/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
+ --lang-dir data/lang_char \
+ --use-averaged-model True \
+ --decoding-method fast_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
+./pruned_transducer_stateless5/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.lexicon import Lexicon
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+        help="""Path to the lang directory, e.g., data/lang_char.
+        """,
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame. Used only when
+ --decoding-method is greedy_search.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert sample_rate == expected_sample_rate, (
+ f"expected sample rate: {expected_sample_rate}. "
+ f"Given: {sample_rate}"
+ )
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+
+ params.update(vars(args))
+
+ lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info("Creating model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
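+    # These options should match the feature extraction used during
+    # training: 16 kHz audio, 80-dim log-mel filter banks, and dithering
+    # disabled for deterministic output.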
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(
+ features, batch_first=True, padding_value=math.log(1e-10)
+ )
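+    # Padding frames are filled with log(1e-10), i.e. a very small value in
+    # the log-mel domain.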
+
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=features, x_lens=feature_lengths
+ )
+
+ num_waves = encoder_out.size(0)
+ hyps = []
+ msg = f"Using {params.decoding_method}"
+ if params.decoding_method == "beam_search":
+ msg += f" with beam size {params.beam_size}"
+ logging.info(msg)
+
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ for i in range(num_waves):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding-method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/scaling.py b/egs/aishell4/ASR/pruned_transducer_stateless5/scaling.py
new file mode 120000
index 000000000..be7b111c6
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/scaling.py
@@ -0,0 +1 @@
+../../../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/test_model.py b/egs/aishell4/ASR/pruned_transducer_stateless5/test_model.py
new file mode 100755
index 000000000..d42c3b4f4
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/test_model.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+ cd icefall/egs/aishell4/ASR
+ python ./pruned_transducer_stateless5/test_model.py
+"""
+
+from train import get_params, get_transducer_model
+
+
+def test_model_1():
+ params = get_params()
+ params.vocab_size = 500
+ params.blank_id = 0
+ params.context_size = 2
+ params.num_encoder_layers = 24
+ params.dim_feedforward = 1536 # 384 * 4
+ params.encoder_dim = 384
+ model = get_transducer_model(params)
+ num_param = sum([p.numel() for p in model.parameters()])
+ print(f"Number of model parameters: {num_param}")
+
+
+# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
+def test_model_M():
+ params = get_params()
+ params.vocab_size = 500
+ params.blank_id = 0
+ params.context_size = 2
+ params.num_encoder_layers = 18
+ params.dim_feedforward = 1024
+ params.encoder_dim = 256
+ params.nhead = 4
+ params.decoder_dim = 512
+ params.joiner_dim = 512
+ model = get_transducer_model(params)
+ num_param = sum([p.numel() for p in model.parameters()])
+ print(f"Number of model parameters: {num_param}")
+
+
+def main():
+ # test_model_1()
+ test_model_M()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
new file mode 100755
index 000000000..0a48b9059
--- /dev/null
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py
@@ -0,0 +1,1108 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import Aishell4AsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from local.text_normalize import text_normalize
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[
+ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=int,
+ default=24,
+        help="Number of conformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--dim-feedforward",
+ type=int,
+ default=1536,
+ help="Feedforward dimension of the conformer encoder layer.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=int,
+ default=8,
+ help="Number of attention heads in the conformer encoder layer.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=int,
+ default=384,
+ help="Attention dimension in the conformer encoder layer.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless5/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+        help="""The lang dir.
+        It contains language related input files such as
+        "lexicon.txt".
+ """,
+ )
+
+ parser.add_argument(
+ "--initial-lr",
+ type=float,
+ default=0.003,
+ help="The initial learning rate. This value should not need "
+ "to be changed.",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not changing this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+        help="The prune range for the RNN-T loss; it means how many symbols "
+        "(context) we use to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)"
+ "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of "
+        "the loss (where the joiner is just an addition); this simple loss "
+        "is also used for training (as a regularization term). We scale the "
+        "simple loss with this parameter before adding it to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+        help="""Save checkpoint after processing this number of batches
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'.
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+        """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=100,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
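+    # For example, with average_period=100 and batch_idx_train=400, the
+    # update above gives the current model a weight of 0.25 and the running
+    # average a weight of 0.75.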
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train:  Used to write statistics to tensorboard. It
+                            contains the number of batches trained so far
+                            across epochs.
+
+        - log_interval:  Print training loss if `batch_idx % log_interval` is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer
+                              decoder.
+
+        - model_warm_step: Number of batches used for model warm-up; it is
+                           passed to the model as the `warmup` argument and
+                           does not affect the learning rate.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 100,
+ "valid_interval": 200,
+ # parameters for conformer
+ "feature_dim": 80,
+ "subsampling_factor": 4,
+ # parameters for Noam
+ "model_warm_step": 400, # arg given to model, not for lrate
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Conformer and Transformer
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.encoder_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+ warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+    Compute the transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = (
+ model.device
+ if isinstance(model, DDP)
+ else next(model.parameters()).device
+ )
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ texts = batch["supervisions"]["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ if type(y) == list:
+ y = k2.RaggedTensor(y).to(device)
+ else:
+ y = y.to(device)
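+    # y is now a k2.RaggedTensor of token IDs; each sub-list corresponds to
+    # one utterance in the batch.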
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ warmup=warmup,
+ )
+ # after the main warmup step, we keep pruned_loss_scale small
+ # for the same amount of time (model_warm_step), to avoid
+ # overwhelming the simple_loss and causing it to diverge,
+ # in case it had not fully learned the alignment yet.
+ pruned_loss_scale = (
+ 0.0
+ if warmup < 1.0
+ else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+ )
+ loss = (
+ params.simple_loss_scale * simple_loss
+ + pruned_loss_scale * pruned_loss
+ )
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (
+ (feature_lens // params.subsampling_factor).sum().item()
+ )
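+        # "frames" is the number of frames after subsampling; it is used to
+        # normalize the losses in the logs.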
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+        The learning rate scheduler; we call step_batch() for every batch
+        and step_epoch() for every epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+ # print(batch["supervisions"])
+
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ warmup=(params.batch_idx_train / params.model_warm_step),
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
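+        # tot_loss is a decaying sum of recent batch statistics, with a time
+        # constant of roughly `reset_interval` batches.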
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}"
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+    params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank])
+
+ optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
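+    # Eden decays the learning rate as a function of both the number of
+    # processed batches (--lr-batches) and the number of finished epochs
+    # (--lr-epochs); see optim.py for the exact formula.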
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2 ** 22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ aishell4 = Aishell4AsrDataModule(args)
+ # Combine all of the training data
+ train_cuts = aishell4.train_S_cuts()
+ train_cuts += aishell4.train_M_cuts()
+ train_cuts += aishell4.train_L_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ return 1.0 <= c.duration <= 20.0
+
+ def text_normalize_for_cut(c: Cut):
+ # Text normalize for each sample
+ text = c.supervisions[0].text
+ text = text.strip("\n").strip("\t")
+ c.supervisions[0].text = text_normalize(text)
+ return c
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+ train_cuts = train_cuts.map(text_normalize_for_cut)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = aishell4.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = aishell4.valid_cuts()
+ valid_cuts = valid_cuts.map(text_normalize_for_cut)
+ valid_dl = aishell4.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+ # (i.e. are not remembered by the decaying-average in adam), because
+ # we want to avoid these params being subject to shrinkage in adam.
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ warmup=0.0,
+ )
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ Aishell4AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell4/ASR/shared b/egs/aishell4/ASR/shared
new file mode 120000
index 000000000..3a3b28f96
--- /dev/null
+++ b/egs/aishell4/ASR/shared
@@ -0,0 +1 @@
+../../../egs/aishell/ASR/shared
\ No newline at end of file
diff --git a/egs/alimeeting/ASR/README.md b/egs/alimeeting/ASR/README.md
new file mode 100644
index 000000000..257fe38d5
--- /dev/null
+++ b/egs/alimeeting/ASR/README.md
@@ -0,0 +1,19 @@
+
+# Introduction
+
+This recipe contains various ASR models trained with the AliMeeting (far-field) dataset.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# Transducers
+
+This directory contains several folders whose names include `transducer`.
+The following table lists the differences among them.
+
+| | Encoder | Decoder | Comment |
+|---------------------------------------|---------------------|--------------------|-----------------------------|
+| `pruned_transducer_stateless2` | Conformer (modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
+
+The decoder (i.e., the prediction network) is modified from the paper
+[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer.
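+
+For orientation, a minimal sketch of such a stateless decoder is shown below.
+This is not the recipe's actual implementation; the layer sizes, the context
+size, and the convolution settings here are illustrative assumptions only.
+
+```
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class StatelessDecoder(nn.Module):
+    """Sketch of an Embedding + Conv1d prediction network (no recurrence)."""
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embed_dim: int = 512,   # illustrative size
+        context_size: int = 2,  # number of previous tokens mixed by the Conv1d
+        blank_id: int = 0,
+    ):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=blank_id)
+        # The Conv1d mixes the embeddings of the last `context_size` symbols.
+        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)
+        self.context_size = context_size
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        # y: (N, U) previously emitted token IDs
+        emb = self.embedding(y).permute(0, 2, 1)      # (N, C, U)
+        emb = F.pad(emb, (self.context_size - 1, 0))  # causal left-padding
+        out = self.conv(emb).permute(0, 2, 1)         # (N, U, C)
+        return torch.relu(out)
+```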
diff --git a/egs/alimeeting/ASR/RESULTS.md b/egs/alimeeting/ASR/RESULTS.md
new file mode 100644
index 000000000..745795a20
--- /dev/null
+++ b/egs/alimeeting/ASR/RESULTS.md
@@ -0,0 +1,71 @@
+## Results
+
+### Alimeeting Char training results (Pruned Transducer Stateless2)
+
+#### 2022-06-01
+
+Using the codes from this PR https://github.com/k2-fsa/icefall/pull/378.
+
+The WERs are
+| | eval | test | comment |
+|------------------------------------|------------|------------|------------------------------------------|
+| greedy search | 31.77 | 34.66 | --epoch 29, --avg 18, --max-duration 100 |
+| modified beam search (beam size 4) | 30.38 | 33.02 | --epoch 29, --avg 18, --max-duration 100 |
+| fast beam search (set as default) | 31.39 | 34.25 | --epoch 29, --avg 18, --max-duration 1500|
+
+The training command for reproducing the above results is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless2/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 220 \
+ --save-every-n 1000
+
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/AoqgSvZKTZCJhJbOuG3W6g/#scalars
+
+The decoding command is:
+```
+epoch=29
+avg=18
+
+## greedy search
+./pruned_transducer_stateless2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 100
+
+## modified beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 100 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+## fast beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 1500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+A pre-trained model and decoding logs can be found at
diff --git a/egs/alimeeting/ASR/local/__init__.py b/egs/alimeeting/ASR/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
new file mode 100755
index 000000000..2ff473c60
--- /dev/null
+++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the AliMeeting dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_alimeeting(num_mel_bins: int = 80):
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+ num_jobs = min(15, os.cpu_count())
+
+ dataset_parts = (
+ "train",
+ "eval",
+ "test",
+ )
+
+ prefix = "alimeeting"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ if "train" in partition:
+ cut_set = (
+ cut_set
+ + cut_set.perturb_speed(0.9)
+ + cut_set.perturb_speed(1.1)
+ )
+ cur_num_jobs = num_jobs if ex is None else 80
+ cur_num_jobs = min(cur_num_jobs, len(cut_set))
+
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=cur_num_jobs,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+
+ logging.info("About splitting cuts into smaller chunks")
+ cut_set = cut_set.trim_to_supervisions(
+ keep_overlapping=False,
+ min_duration=None,
+ )
+ cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--num-mel-bins",
+ type=int,
+ default=80,
+ help="""The number of mel bins for Fbank""",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ compute_fbank_alimeeting(num_mel_bins=args.num_mel_bins)
diff --git a/egs/alimeeting/ASR/local/compute_fbank_musan.py b/egs/alimeeting/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/alimeeting/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR/local/display_manifest_statistics.py b/egs/alimeeting/ASR/local/display_manifest_statistics.py
new file mode 100644
index 000000000..16cdecc91
--- /dev/null
+++ b/egs/alimeeting/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,96 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed value to choose minimum/maximum duration
+to remove short and long utterances during the training.
+See the function `remove_short_and_long_utt()`
+in ../../../librispeech/ASR/transducer/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest_lazy
+
+
+def main():
+ paths = [
+ "./data/fbank/alimeeting_cuts_train.jsonl.gz",
+ "./data/fbank/alimeeting_cuts_eval.jsonl.gz",
+ "./data/fbank/alimeeting_cuts_test.jsonl.gz",
+ ]
+
+ for path in paths:
+ print(f"Starting display the statistics for {path}")
+ cuts = load_manifest_lazy(path)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+Starting to display the statistics for ./data/fbank/alimeeting_cuts_train.jsonl.gz
+Cuts count: 559092
+Total duration (hours): 424.6
+Speech duration (hours): 424.6 (100.0%)
+***
+Duration statistics (seconds):
+mean 2.7
+std 3.0
+min 0.0
+25% 0.7
+50% 1.7
+75% 3.6
+99% 13.6
+99.5% 14.7
+99.9% 16.2
+max 284.3
+Starting to display the statistics for ./data/fbank/alimeeting_cuts_eval.jsonl.gz
+Cuts count: 6457
+Total duration (hours): 4.9
+Speech duration (hours): 4.9 (100.0%)
+***
+Duration statistics (seconds):
+mean 2.7
+std 3.1
+min 0.1
+25% 0.6
+50% 1.6
+75% 3.5
+99% 13.6
+99.5% 14.1
+99.9% 14.7
+max 15.8
+Starting to display the statistics for ./data/fbank/alimeeting_cuts_test.jsonl.gz
+Cuts count: 16358
+Total duration (hours): 12.5
+Speech duration (hours): 12.5 (100.0%)
+***
+Duration statistics (seconds):
+mean 2.7
+std 2.9
+min 0.1
+25% 0.7
+50% 1.7
+75% 3.5
+99% 13.7
+99.5% 14.2
+99.9% 14.8
+max 15.7
+"""
diff --git a/egs/alimeeting/ASR/local/prepare_char.py b/egs/alimeeting/ASR/local/prepare_char.py
new file mode 100755
index 000000000..d9e47d17a
--- /dev/null
+++ b/egs/alimeeting/ASR/local/prepare_char.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+
+This script takes as input `lang_dir`, which should contain::
+
+ - lang_dir/text,
+ - lang_dir/words.txt
+
+and generates the following files in the directory `lang_dir`:
+
+ - lexicon.txt
+ - lexicon_disambig.txt
+ - L.pt
+ - L_disambig.pt
+ - tokens.txt
+"""
+
+import re
+from pathlib import Path
+from typing import Dict, List
+
+import k2
+import torch
+from prepare_lang import (
+ Lexicon,
+ add_disambig_symbols,
+ add_self_loops,
+ write_lexicon,
+ write_mapping,
+)
+
+
+def lexicon_to_fst_no_sil(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format).
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ loop_state = 0 # words enter and leave from here
+ next_state = 1 # the next un-allocated state, will be incremented as we go
+
+ arcs = []
+
+ # The blank symbol is defined in local/train_bpe_model.py
+ assert token2id[""] == 0
+ assert word2id[""] == 0
+
+ eps = 0
+
+ for word, pieces in lexicon:
+ assert len(pieces) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ pieces = [
+ token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+ ]
+
+ for i in range(len(pieces) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, pieces[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last piece of this word
+ i = len(pieces) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, pieces[i], w, 0])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
+ """Check if all the given tokens are in token symbol table.
+
+ Args:
+ token_sym_table:
+ Token symbol table that contains all the valid tokens.
+ tokens:
+ A list of tokens.
+ Returns:
+ Return True if there is any token not in the token_sym_table,
+ otherwise False.
+ """
+ for tok in tokens:
+ if tok not in token_sym_table:
+ return True
+ return False
+
+
+def generate_lexicon(
+ token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
+ """Generate a lexicon from a word list and token_sym_table.
+
+ Args:
+ token_sym_table:
+ Token symbol table that mapping token to token ids.
+ words:
+ A list of strings representing words.
+ Returns:
+ Return the generated lexicon: a list of (word, characters) pairs.
+ """
+ lexicon = []
+ for word in words:
+ chars = list(word.strip(" \t"))
+ if contain_oov(token_sym_table, chars):
+ continue
+ lexicon.append((word, chars))
+
+ # The OOV word is <UNK>
+ lexicon.append(("<UNK>", ["<unk>"]))
+ return lexicon
+
+
+def generate_tokens(text_file: str) -> Dict[str, int]:
+ """Generate tokens from the given text file.
+
+ Args:
+ text_file:
+ A file that contains text lines to generate tokens.
+ Returns:
+ Return a dict whose keys are tokens and values are token ids ranging
+ from 0 to len(keys) - 1.
+ """
+ tokens: Dict[str, int] = dict()
+ tokens[""] = 0
+ tokens[""] = 1
+ tokens[""] = 2
+ whitespace = re.compile(r"([ \t\r\n]+)")
+ with open(text_file, "r", encoding="utf-8") as f:
+ for line in f:
+ line = re.sub(whitespace, "", line)
+ chars = list(line)
+ for char in chars:
+ if char not in tokens:
+ tokens[char] = len(tokens)
+ return tokens
+
+
+def main():
+ lang_dir = Path("data/lang_char")
+ text_file = lang_dir / "text"
+
+ word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+ words = word_sym_table.symbols
+
+ excluded = ["", "!SIL", "", "", "#0", "", ""]
+ for w in excluded:
+ if w in words:
+ words.remove(w)
+
+ token_sym_table = generate_tokens(text_file)
+
+ lexicon = generate_lexicon(token_sym_table, words)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ next_token_id = max(token_sym_table.values()) + 1
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in token_sym_table
+ token_sym_table[disambig] = next_token_id
+ next_token_id += 1
+
+ word_sym_table.add("#0")
+ word_sym_table.add("")
+ word_sym_table.add("")
+
+ write_mapping(lang_dir / "tokens.txt", token_sym_table)
+
+ write_lexicon(lang_dir / "lexicon.txt", lexicon)
+ write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst_no_sil(
+ lexicon,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ )
+
+ L_disambig = lexicon_to_fst_no_sil(
+ lexicon_disambig,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), lang_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py
new file mode 100755
index 000000000..e5ae89ec4
--- /dev/null
+++ b/egs/alimeeting/ASR/local/prepare_lang.py
@@ -0,0 +1,390 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
+consisting of words and tokens (i.e., phones) and does the following:
+
+1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
+
+2. Generate tokens.txt, the token table mapping a token to a unique integer.
+
+3. Generate words.txt, the word table mapping a word to a unique integer.
+
+4. Generate L.pt, in k2 format. It can be loaded by
+
+ d = torch.load("L.pt")
+ lexicon = k2.Fsa.from_dict(d)
+
+5. Generate L_disambig.pt, in k2 format.
+"""
+import argparse
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import k2
+import torch
+
+from icefall.lexicon import read_lexicon, write_lexicon
+
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_tokens(lexicon: Lexicon) -> List[str]:
+ """Get tokens from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique tokens.
+ """
+ ans = set()
+ for _, tokens in lexicon:
+ ans.update(tokens)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def get_words(lexicon: Lexicon) -> List[str]:
+ """Get words from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique words.
+ """
+ ans = set()
+ for word, _ in lexicon:
+ ans.add(word)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
+ """It adds pseudo-token disambiguation symbols #1, #2 and so on
+ at the ends of tokens to ensure that all pronunciations are different,
+ and that none is a prefix of another.
+
+ See also add_lex_disambig.pl from kaldi.
+
+ Args:
+ lexicon:
+ It is returned by :func:`read_lexicon`.
+ Returns:
+ Return a tuple with two elements:
+
+ - The output lexicon with disambiguation symbols
+ - The ID of the max disambiguation symbol that appears
+ in the lexicon
+ """
+
+ # (1) Work out the count of each token-sequence in the
+ # lexicon.
+ count = defaultdict(int)
+ for _, tokens in lexicon:
+ count[" ".join(tokens)] += 1
+
+ # (2) For each left sub-sequence of each token-sequence, note down
+ # that it exists (for identifying prefixes of longer strings).
+ issubseq = defaultdict(int)
+ for _, tokens in lexicon:
+ tokens = tokens.copy()
+ tokens.pop()
+ while tokens:
+ issubseq[" ".join(tokens)] = 1
+ tokens.pop()
+
+ # (3) For each entry in the lexicon:
+ # if the token sequence is unique and is not a
+ # prefix of another word, no disambig symbol.
+ # Else output #1, or #2, #3, ... if the same token-seq
+ # has already been assigned a disambig symbol.
+ ans = []
+
+ # We start with #1 since #0 has its own purpose
+ first_allowed_disambig = 1
+ max_disambig = first_allowed_disambig - 1
+ last_used_disambig_symbol_of = defaultdict(int)
+
+ for word, tokens in lexicon:
+ tokenseq = " ".join(tokens)
+ assert tokenseq != ""
+ if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
+ ans.append((word, tokens))
+ continue
+
+ cur_disambig = last_used_disambig_symbol_of[tokenseq]
+ if cur_disambig == 0:
+ cur_disambig = first_allowed_disambig
+ else:
+ cur_disambig += 1
+
+ if cur_disambig > max_disambig:
+ max_disambig = cur_disambig
+ last_used_disambig_symbol_of[tokenseq] = cur_disambig
+ tokenseq += f" #{cur_disambig}"
+ ans.append((word, tokenseq.split()))
+ return ans, max_disambig
+
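+# For example, in the toy lexicon used by ./test_prepare_lang.py, "food" and
+# "food2" share the token sequence "f o o d"; add_disambig_symbols() rewrites
+# them as "f o o d #1" and "f o o d #2" so that the two entries remain
+# distinguishable.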
+
+def generate_id_map(symbols: List[str]) -> Dict[str, int]:
+ """Generate ID maps, i.e., map a symbol to a unique ID.
+
+ Args:
+ symbols:
+ A list of unique symbols.
+ Returns:
+ A dict containing the mapping between symbols and IDs.
+ """
+ return {sym: i for i, sym in enumerate(symbols)}
+
+
+def add_self_loops(
+ arcs: List[List[Any]], disambig_token: int, disambig_word: int
+) -> List[List[Any]]:
+ """Adds self-loops to states of an FST to propagate disambiguation symbols
+ through it. They are added on each state with non-epsilon output symbols
+ on at least one arc out of the state.
+
+ See also fstaddselfloops.pl from Kaldi. One difference is that
+ Kaldi uses OpenFst style FSTs and it has multiple final states.
+ This function uses k2 style FSTs and it does not need to add self-loops
+ to the final state.
+
+ The input label of a self-loop is `disambig_token`, while the output
+ label is `disambig_word`.
+
+ Args:
+ arcs:
+ A list-of-list. The sublist contains
+ `[src_state, dest_state, label, aux_label, score]`
+ disambig_token:
+ It is the token ID of the symbol `#0`.
+ disambig_word:
+ It is the word ID of the symbol `#0`.
+
+ Return:
+ Return new `arcs` containing self-loops.
+ """
+ states_needs_self_loops = set()
+ for arc in arcs:
+ src, dst, ilabel, olabel, score = arc
+ if olabel != 0:
+ states_needs_self_loops.add(src)
+
+ ans = []
+ for s in states_needs_self_loops:
+ ans.append([s, s, disambig_token, disambig_word, 0])
+
+ return arcs + ans
+
+
+def lexicon_to_fst(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ sil_token: str = "SIL",
+ sil_prob: float = 0.5,
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format) with optional silence at
+ the beginning and end of each word.
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ sil_token:
+ The silence token.
+ sil_prob:
+ The probability for adding a silence at the beginning and end
+ of the word.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ assert sil_prob > 0.0 and sil_prob < 1.0
+ # CAUTION: we use score, i.e, negative cost.
+ sil_score = math.log(sil_prob)
+ no_sil_score = math.log(1.0 - sil_prob)
+
+ start_state = 0
+ loop_state = 1 # words enter and leave from here
+ sil_state = 2 # words terminate here when followed by silence; this state
+ # has a silence transition to loop_state.
+ next_state = 3 # the next un-allocated state, will be incremented as we go.
+ arcs = []
+
+ assert token2id[""] == 0
+ assert word2id[""] == 0
+
+ eps = 0
+
+ sil_token = token2id[sil_token]
+
+ arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+ arcs.append([start_state, sil_state, eps, eps, sil_score])
+ arcs.append([sil_state, loop_state, sil_token, eps, 0])
+
+ for word, tokens in lexicon:
+ assert len(tokens) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ tokens = [token2id[i] for i in tokens]
+
+ for i in range(len(tokens) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, tokens[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last token of this word
+ # It has two out-going arcs, one to the loop state,
+ # the other one to the sil_state.
+ i = len(tokens) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
+ arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir", type=str, help="The lang dir, data/lang_phone"
+ )
+ return parser.parse_args()
+
+
+def main():
+ out_dir = Path(get_args().lang_dir)
+ lexicon_filename = out_dir / "lexicon.txt"
+ sil_token = "SIL"
+ sil_prob = 0.5
+
+ lexicon = read_lexicon(lexicon_filename)
+ tokens = get_tokens(lexicon)
+ words = get_words(lexicon)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in tokens
+ tokens.append(f"#{i}")
+
+ assert "" not in tokens
+ tokens = [""] + tokens
+
+ assert "" not in words
+ assert "#0" not in words
+ assert "" not in words
+ assert "" not in words
+
+ words = [""] + words + ["#0", "", ""]
+
+ token2id = generate_id_map(tokens)
+ word2id = generate_id_map(words)
+
+ write_mapping(out_dir / "tokens.txt", token2id)
+ write_mapping(out_dir / "words.txt", word2id)
+ write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst(
+ lexicon,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ )
+
+ L_disambig = lexicon_to_fst(
+ lexicon_disambig,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), out_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
+
+ if False:
+ # Just for debugging, will remove it
+ L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
+ L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
+ L_disambig.labels_sym = L.labels_sym
+ L_disambig.aux_labels_sym = L.aux_labels_sym
+ L.draw(out_dir / "L.png", title="L")
+ L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/ASR/local/prepare_words.py b/egs/alimeeting/ASR/local/prepare_words.py
new file mode 100755
index 000000000..65aca2983
--- /dev/null
+++ b/egs/alimeeting/ASR/local/prepare_words.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input words.txt without ids:
+ - words_no_ids.txt
+and generates a new words.txt with corresponding ids:
+ - words.txt
+"""
+
+
+import argparse
+import logging
+
+from tqdm import tqdm
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Prepare words.txt",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--input-file",
+ default="data/lang_char/words_no_ids.txt",
+ type=str,
+ help="the words file without ids for WenetSpeech",
+ )
+ parser.add_argument(
+ "--output-file",
+ default="data/lang_char/words.txt",
+ type=str,
+ help="the words file with ids for WenetSpeech",
+ )
+
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ input_file = args.input_file
+ output_file = args.output_file
+
+ f = open(input_file, "r", encoding="utf-8")
+ lines = f.readlines()
+ new_lines = []
+ add_words = [" 0", "!SIL 1", " 2", " 3"]
+ new_lines.extend(add_words)
+
+ logging.info("Starting reading the input file")
+ for i in tqdm(range(len(lines))):
+ x = lines[i]
+ idx = 4 + i
+ new_line = str(x.strip("\n")) + " " + str(idx)
+ new_lines.append(new_line)
+
+ logging.info("Starting writing the words.txt")
+ f_out = open(output_file, "w", encoding="utf-8")
+ for line in new_lines:
+ f_out.write(line)
+ f_out.write("\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/ASR/local/test_prepare_lang.py b/egs/alimeeting/ASR/local/test_prepare_lang.py
new file mode 100755
index 000000000..d4cf62bba
--- /dev/null
+++ b/egs/alimeeting/ASR/local/test_prepare_lang.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+
+import os
+import tempfile
+
+import k2
+from prepare_lang import (
+ add_disambig_symbols,
+ generate_id_map,
+ get_tokens,
+ get_words,
+ lexicon_to_fst,
+ read_lexicon,
+ write_lexicon,
+ write_mapping,
+)
+
+
+def generate_lexicon_file() -> str:
+ fd, filename = tempfile.mkstemp()
+ os.close(fd)
+ s = """
+ !SIL SIL
+ <SPOKEN_NOISE> SPN
+ <UNK> SPN
+ f f
+ a a
+ foo f o o
+ bar b a r
+ bark b a r k
+ food f o o d
+ food2 f o o d
+ fo f o
+ """.strip()
+ with open(filename, "w") as f:
+ f.write(s)
+ return filename
+
+
+def test_read_lexicon(filename: str):
+ lexicon = read_lexicon(filename)
+ phones = get_tokens(lexicon)
+ words = get_words(lexicon)
+ print(lexicon)
+ print(phones)
+ print(words)
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+ print(lexicon_disambig)
+ print("max disambig:", f"#{max_disambig}")
+
+ phones = ["", "SIL", "SPN"] + phones
+ for i in range(max_disambig + 1):
+ phones.append(f"#{i}")
+ words = [""] + words
+
+ phone2id = generate_id_map(phones)
+ word2id = generate_id_map(words)
+
+ print(phone2id)
+ print(word2id)
+
+ write_mapping("phones.txt", phone2id)
+ write_mapping("words.txt", word2id)
+
+ write_lexicon("a.txt", lexicon)
+ write_lexicon("a_disambig.txt", lexicon_disambig)
+
+ fsa = lexicon_to_fst(lexicon, token2id=phone2id, word2id=word2id)
+ fsa.labels_sym = k2.SymbolTable.from_file("phones.txt")
+ fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
+ fsa.draw("L.pdf", title="L")
+
+ fsa_disambig = lexicon_to_fst(
+ lexicon_disambig, token2id=phone2id, word2id=word2id
+ )
+ fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
+ fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
+ fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
+
+
+def main():
+ filename = generate_lexicon_file()
+ test_read_lexicon(filename)
+ os.remove(filename)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/ASR/local/text2segments.py b/egs/alimeeting/ASR/local/text2segments.py
new file mode 100644
index 000000000..3df727c67
--- /dev/null
+++ b/egs/alimeeting/ASR/local/text2segments.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input "text", which refers to the transcript file for
+WenetSpeech:
+ - text
+and generates the output file text_word_segmentation which is implemented
+with word segmenting:
+ - text_words_segmentation
+"""
+
+
+import argparse
+
+import jieba
+from tqdm import tqdm
+
+jieba.enable_paddle()
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Chinese Word Segmentation for text",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--input-file",
+ default="data/lang_char/text",
+ type=str,
+ help="the input text file for WenetSpeech",
+ )
+ parser.add_argument(
+ "--output-file",
+ default="data/lang_char/text_words_segmentation",
+ type=str,
+ help="the text implemented with words segmenting for WenetSpeech",
+ )
+
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ input_file = args.input_file
+ output_file = args.output_file
+
+ f = open(input_file, "r", encoding="utf-8")
+ lines = f.readlines()
+ new_lines = []
+ for i in tqdm(range(len(lines))):
+ x = lines[i].rstrip()
+ seg_list = jieba.cut(x, use_paddle=True)
+ new_line = " ".join(seg_list)
+ new_lines.append(new_line)
+
+ f_new = open(output_file, "w", encoding="utf-8")
+ for line in new_lines:
+ f_new.write(line)
+ f_new.write("\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/ASR/local/text2token.py b/egs/alimeeting/ASR/local/text2token.py
new file mode 100755
index 000000000..71be2a613
--- /dev/null
+++ b/egs/alimeeting/ASR/local/text2token.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env python3
+# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe)
+# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import codecs
+import re
+import sys
+from typing import List
+
+from pypinyin import lazy_pinyin, pinyin
+
+is_python2 = sys.version_info[0] == 2
+
+
+def exist_or_not(i, match_pos):
+ start_pos = None
+ end_pos = None
+ for pos in match_pos:
+ if pos[0] <= i < pos[1]:
+ start_pos = pos[0]
+ end_pos = pos[1]
+ break
+
+ return start_pos, end_pos
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="convert raw text to tokenized text",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--nchar",
+ "-n",
+ default=1,
+ type=int,
+ help="number of characters to split, i.e., \
+ aabb -> a a b b with -n 1 and aa bb with -n 2",
+ )
+ parser.add_argument(
+ "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
+ )
+ parser.add_argument(
+ "--space", default="", type=str, help="space symbol"
+ )
+ parser.add_argument(
+ "--non-lang-syms",
+ "-l",
+ default=None,
+ type=str,
+ help="list of non-linguistic symobles, e.g., etc.",
+ )
+ parser.add_argument(
+ "text", type=str, default=False, nargs="?", help="input text"
+ )
+ parser.add_argument(
+ "--trans_type",
+ "-t",
+ type=str,
+ default="char",
+ choices=["char", "pinyin", "lazy_pinyin"],
+ help="""Transcript type. char/pinyin/lazy_pinyin""",
+ )
+ return parser
+
+
+def token2id(
+ texts, token_table, token_type: str = "lazy_pinyin", oov: str = ""
+) -> List[List[int]]:
+ """Convert token to id.
+ Args:
+ texts:
+ The input texts, it refers to the chinese text here.
+ token_table:
+ The token table is built based on "data/lang_xxx/token.txt"
+ token_type:
+ The type of token, such as "pinyin" and "lazy_pinyin".
+ oov:
+ Out of vocabulary token. When a word(token) in the transcript
+ does not exist in the token list, it is replaced with `oov`.
+
+ Returns:
+ The list of ids for the input texts.
+ """
+ if texts is None:
+ raise ValueError("texts can't be None!")
+ else:
+ oov_id = token_table[oov]
+ ids: List[List[int]] = []
+ for text in texts:
+ chars_list = list(str(text))
+ if token_type == "lazy_pinyin":
+ text = lazy_pinyin(chars_list)
+ sub_ids = [
+ token_table[txt] if txt in token_table else oov_id
+ for txt in text
+ ]
+ ids.append(sub_ids)
+ else: # token_type = "pinyin"
+ text = pinyin(chars_list)
+ sub_ids = [
+ token_table[txt[0]] if txt[0] in token_table else oov_id
+ for txt in text
+ ]
+ ids.append(sub_ids)
+ return ids
+
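+# For reference, the two pypinyin helpers used above behave roughly as follows
+# (illustrative; the exact output depends on the pypinyin version and style):
+#
+#   lazy_pinyin(["中", "心"]) -> ["zhong", "xin"]
+#   pinyin(["中", "心"])      -> [["zhōng"], ["xīn"]]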
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ rs = []
+ if args.non_lang_syms is not None:
+ with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
+ nls = [x.rstrip() for x in f.readlines()]
+ rs = [re.compile(re.escape(x)) for x in nls]
+
+ if args.text:
+ f = codecs.open(args.text, encoding="utf-8")
+ else:
+ f = codecs.getreader("utf-8")(
+ sys.stdin if is_python2 else sys.stdin.buffer
+ )
+
+ sys.stdout = codecs.getwriter("utf-8")(
+ sys.stdout if is_python2 else sys.stdout.buffer
+ )
+ line = f.readline()
+ n = args.nchar
+ while line:
+ x = line.split()
+ print(" ".join(x[: args.skip_ncols]), end=" ")
+ a = " ".join(x[args.skip_ncols :]) # noqa E203
+
+ # get all matched positions
+ match_pos = []
+ for r in rs:
+ i = 0
+ while i >= 0:
+ m = r.search(a, i)
+ if m:
+ match_pos.append([m.start(), m.end()])
+ i = m.end()
+ else:
+ break
+ if len(match_pos) > 0:
+ chars = []
+ i = 0
+ while i < len(a):
+ start_pos, end_pos = exist_or_not(i, match_pos)
+ if start_pos is not None:
+ chars.append(a[start_pos:end_pos])
+ i = end_pos
+ else:
+ chars.append(a[i])
+ i += 1
+ a = chars
+
+ if args.trans_type == "pinyin":
+ a = pinyin(list(str(a)))
+ a = [one[0] for one in a]
+
+ if args.trans_type == "lazy_pinyin":
+ a = lazy_pinyin(list(str(a)))
+
+ a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203
+
+ a_flat = []
+ for z in a:
+ a_flat.append("".join(z))
+
+ a_chars = "".join(a_flat)
+ print(a_chars)
+ line = f.readline()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh
new file mode 100755
index 000000000..eb2ac697d
--- /dev/null
+++ b/egs/alimeeting/ASR/prepare.sh
@@ -0,0 +1,133 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/alimeeting
+# This directory contains the following files downloaded from
+# https://openslr.org/62/
+#
+# - Train_Ali_far.tar.gz
+# - Train_Ali_near.tar.gz
+# - Test_Ali.tar.gz
+# - Eval_Ali.tar.gz
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ if [ ! -f $dl_dir/alimeeting/Train_Ali_far.tar.gz ]; then
+ lhotse download ali-meeting $dl_dir/alimeeting
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare alimeeting manifest"
+ # We assume that you have downloaded the alimeeting corpus
+ # to $dl_dir/alimeeting
+ if [ ! -f data/manifests/alimeeting/.manifests.done ]; then
+ mkdir -p data/manifests/alimeeting
+ lhotse prepare ali-meeting $dl_dir/alimeeting data/manifests/alimeeting
+ touch data/manifests/alimeeting/.manifests.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Process alimeeting"
+ if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
+ mkdir -p data/fbank/alimeeting
+ lhotse prepare ali-meeting $dl_dir/alimeeting data/manifests/alimeeting
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to data/musan
+ if [ ! -f data/manifests/.musan_manifests.done ]; then
+ log "It may take 6 minutes"
+ mkdir -p data/manifests
+ lhotse prepare musan $dl_dir/musan data/manifests
+ touch data/manifests/.musan_manifests.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank for musan"
+ if [ ! -f data/fbank/.musan.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_musan.py
+ touch data/fbank/.musan.done
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Compute fbank for alimeeting"
+ if [ ! -f data/fbank/.alimeeting.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_alimeeting.py
+ touch data/fbank/.alimeeting.done
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Prepare char based lang"
+ lang_char_dir=data/lang_char
+ mkdir -p $lang_char_dir
+
+ # Prepare text.
+ # Note: in Linux, you can install jq with the following command:
+ # wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+ gunzip -c data/manifests/alimeeting/supervisions_train.jsonl.gz \
+ | jq ".text" | sed 's/"//g' \
+ | ./local/text2token.py -t "char" > $lang_char_dir/text
+
+ # Prepare words segments
+ python ./local/text2segments.py \
+ --input-file $lang_char_dir/text \
+ --output-file $lang_char_dir/text_words_segmentation
+
+ cat $lang_char_dir/text_words_segmentation | sed "s/ /\n/g" \
+ | sort -u | sed "/^$/d" \
+ | uniq > $lang_char_dir/words_no_ids.txt
+
+ # Prepare words.txt
+ if [ ! -f $lang_char_dir/words.txt ]; then
+ ./local/prepare_words.py \
+ --input-file $lang_char_dir/words_no_ids.txt \
+ --output-file $lang_char_dir/words.txt
+ fi
+
+ if [ ! -f $lang_char_dir/L_disambig.pt ]; then
+ ./local/prepare_char.py
+ fi
+fi
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/__init__.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
new file mode 100644
index 000000000..bf6faad7a
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -0,0 +1,421 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import torch
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ load_manifest,
+ load_manifest_lazy,
+ set_caching_enabled,
+)
+from lhotse.dataset import (
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SingleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+set_caching_enabled(False)
+torch.set_num_threads(1)
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class AlimeetingAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
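+ # Sketch of typical usage from a training or decoding script (the argument
+ # parsing shown here is illustrative; the recipe's scripts handle it):
+ #
+ #   parser = argparse.ArgumentParser()
+ #   AlimeetingAsrDataModule.add_arguments(parser)
+ #   args = parser.parse_args()
+ #   alimeeting = AlimeetingAsrDataModule(args)
+ #   train_dl = alimeeting.train_dataloaders(alimeeting.train_cuts())
+ #   valid_dl = alimeeting.valid_dataloaders(alimeeting.valid_cuts())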
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/dev/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=300,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(
+ self.args.manifest_dir / "musan_cuts.jsonl.gz"
+ )
+
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ transforms.append(
+ CutMix(
+ cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+ )
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(
+ f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+ )
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ buffer_size=30000,
+ drop_last=True,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SingleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_dl.sampler.load_state_dict(sampler_state_dict)
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+
+ from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+
+ dev_iter_dataset = IterableDatasetWrapper(
+ dataset=validate,
+ sampler=valid_sampler,
+ )
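+ # Wrapping the dataset and sampler in an IterableDatasetWrapper makes the
+ # sampler run inside each dataloader worker, which is what allows the cuts
+ # to come from sharded sources (e.g. the webdataset tar files created by
+ # decode.py) with split_by_worker=True.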
+ valid_dl = DataLoader(
+ dev_iter_dataset,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+
+ test_iter_dataset = IterableDatasetWrapper(
+ dataset=test,
+ sampler=sampler,
+ )
+ test_dl = DataLoader(
+ test_iter_dataset,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "alimeeting_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "alimeeting_cuts_eval.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "alimeeting_cuts_test.jsonl.gz"
+ )
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/beam_search.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/conformer.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/conformer.py
new file mode 120000
index 000000000..a65957180
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
new file mode 100755
index 000000000..cb455838e
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py
@@ -0,0 +1,615 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+When training with the far-field data, usage:
+(1) greedy search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 29 \
+ --avg 18 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 100 \
+ --decoding-method greedy_search
+
+(2) modified beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 29 \
+ --avg 18 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 100 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(3) fast beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 29 \
+ --avg 18 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 1500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import AlimeetingAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="It specifies the checkpoint to use for decoding. "
+ "Note: epoch counts from 0.",
+ )
+
+ parser.add_argument(
+ "--batch",
+ type=int,
+ default=None,
+ help="If given, it specifies the checkpoint exp-dir/checkpoint-{batch}.pt "
+ "(saved every --save-every-n batches during training) to use for decoding.",
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--avg-last-n",
+ type=int,
+ default=0,
+ help="""If positive, --epoch and --avg are ignored and it
+ will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+ where xxx is the number of processed batches while
+ saving that checkpoint.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. It can be either a `k2.trivial_graph` or an HLG.
+ Used only when --decoding-method is fast_beam_search.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
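+ # A hypothetical example of the returned dict for greedy_search and a
+ # batch of two utterances:
+ #   {"greedy_search": [["今", "天", "开", "会"], ["好", "的"]]}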
+ device = model.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+ encoder_out, encoder_out_lens = model.encoder(
+ x=feature, x_lens=feature_lens
+ )
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ decoding_graph:
+ The decoding graph. It can be either a `k2.trivial_graph` or an HLG.
+ Used only when --decoding-method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if a beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
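+ # A hypothetical example of the returned dict:
+ #   {"greedy_search": [(["你", "好"], ["你", "好"]), ...]}
+ # i.e. for each decoding setting, a list of (reference, hypothesis)
+ # character-list pairs.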
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 100
+ else:
+ log_interval = 50
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ texts = [list(str(text).replace(" ", "")) for text in texts]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ this_batch.append((ref_text, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir
+ / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AlimeetingAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if params.avg_last_n > 0:
+ filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ elif params.batch is not None:
+ filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt"
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints([filenames], device=device))
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+
+ average = average_checkpoints(filenames, device=device)
+ checkpoint = {"model": average}
+ torch.save(
+ checkpoint,
+ "pruned_transducer_stateless2/exp/pretrained_epoch_29_avg_18.pt",
+ )
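+ # Note that, in addition to averaging into `model`, the averaged weights
+ # are also written to the hard-coded path above; adjust or remove that
+ # torch.save() call if you decode with a different --epoch/--avg.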
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # Note: Please use "pip install webdataset==0.1.103"
+ # to install webdataset.
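+ # The dev/test cuts are exported to webdataset tar shards below so that
+ # they can be split across dataloader workers and nodes
+ # (split_by_worker=True, split_by_node=True) without every worker
+ # having to hold the whole manifest in memory.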
+ import glob
+ import os
+
+ from lhotse import CutSet
+ from lhotse.dataset.webdataset import export_to_webdataset
+
+ alimeeting = AlimeetingAsrDataModule(args)
+
+ dev = "eval"
+ test = "test"
+
+ if not os.path.exists(f"{dev}/shared-0.tar"):
+ os.makedirs(dev)
+ dev_cuts = alimeeting.valid_cuts()
+ export_to_webdataset(
+ dev_cuts,
+ output_path=f"{dev}/shared-%d.tar",
+ shard_size=300,
+ )
+
+ if not os.path.exists(f"{test}/shared-0.tar"):
+ os.makedirs(test)
+ test_cuts = alimeeting.test_cuts()
+ export_to_webdataset(
+ test_cuts,
+ output_path=f"{test}/shared-%d.tar",
+ shard_size=300,
+ )
+
+ dev_shards = [
+ str(path)
+ for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+ ]
+ cuts_dev_webdataset = CutSet.from_webdataset(
+ dev_shards,
+ split_by_worker=True,
+ split_by_node=True,
+ shuffle_shards=True,
+ )
+
+ test_shards = [
+ str(path)
+ for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
+ ]
+ cuts_test_webdataset = CutSet.from_webdataset(
+ test_shards,
+ split_by_worker=True,
+ split_by_node=True,
+ shuffle_shards=True,
+ )
+
+ def remove_short_and_long_utt(c: Cut):
+ return 1.0 <= c.duration
+
+ cuts_dev_webdataset = cuts_dev_webdataset.filter(remove_short_and_long_utt)
+ cuts_test_webdataset = cuts_test_webdataset.filter(
+ remove_short_and_long_utt
+ )
+
+ dev_dl = alimeeting.valid_dataloaders(cuts_dev_webdataset)
+ test_dl = alimeeting.test_dataloaders(cuts_test_webdataset)
+
+ test_sets = ["dev", "test"]
+ test_dls = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ )
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decoder.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decoder.py
new file mode 120000
index 000000000..722e1c894
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/encoder_interface.py
new file mode 120000
index 000000000..653c5b09a
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
new file mode 100644
index 000000000..8beec1b8a
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
@@ -0,0 +1,181 @@
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless2/export.py \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --epoch 29 \
+ --avg 18
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless2/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/alimeeting/ASR
+ ./pruned_transducer_stateless2/decode.py \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 100 \
+ --lang-dir data/lang_char
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.lexicon import Lexicon
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="It specifies the checkpoint to use for decoding. "
+ "Note: epoch counts from 0.",
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ return parser
+
+
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ model.to(device)
+
+ if params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+
+ model.eval()
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit:
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+ # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
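+ # The scripted model can later be loaded without any of the training
+ # code, e.g. with `torch.jit.load(str(filename))`, which is what
+ # C++/serving frontends typically rely on.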
+ else:
+ logging.info("Not using torch.jit.script")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/joiner.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/joiner.py
new file mode 120000
index 000000000..9052f3cbb
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/model.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/model.py
new file mode 120000
index 000000000..a99e74334
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/model.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/optim.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
new file mode 100644
index 000000000..93b1e1f57
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
@@ -0,0 +1,347 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Here, the far-field data is used for training. Usage:
+
+(1) greedy search
+./pruned_transducer_stateless2/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+ --lang-dir ./data/lang_char \
+ --decoding-method greedy_search \
+ --max-sym-per-frame 1 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./pruned_transducer_stateless2/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+ --lang-dir ./data/lang_char \
+ --decoding-method modified_beam_search \
+ --beam-size 4 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./pruned_transducer_stateless2/pretrained.py \
+ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+ --lang-dir ./data/lang_char \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8 \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by
+./pruned_transducer_stateless2/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
+
+from icefall.lexicon import Lexicon
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Path to lang.
+ """,
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="Used only when --decoding-method is beam_search or modified_beam_search.",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame. Used only when
+ --decoding-method is greedy_search.
+ """,
+ )
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list of 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert sample_rate == expected_sample_rate, (
+ f"expected sample rate: {expected_sample_rate}. "
+ f"Given: {sample_rate}"
+ )
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+
+ params.update(vars(args))
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info("Creating model")
+ model = get_transducer_model(params)
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
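+ # dither=0 and snip_edges=False are meant to match the fbank features
+ # computed with lhotse during training, so that decoding sees the same
+ # kind of input as the model was trained on.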
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(
+ features, batch_first=True, padding_value=math.log(1e-10)
+ )
+
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ with torch.no_grad():
+ encoder_out, encoder_out_lens = model.encoder(
+ x=features, x_lens=feature_lengths
+ )
+
+ hyps = []
+ msg = f"Using {params.decoding_method}"
+ logging.info(msg)
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/scaling.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
new file mode 100644
index 000000000..81a0ede7f
--- /dev/null
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py
@@ -0,0 +1,972 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless2/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 220 \
+ --save-every-n 1000
+
+# For mixed precision training:
+
+./pruned_transducer_stateless2/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --lang-dir data/lang_char \
+ --max-duration 220 \
+ --save-every-n 1000 \
+ --use-fp16 True
+
+"""
+
+import argparse
+import logging
+import os
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AlimeetingAsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[
+ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
+
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
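+# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, which
+# yields more precise stack traces when a kernel fails, at the cost of
+# some training speed.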
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12359,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from this epoch.
+ If it is positive, it will load the checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--initial-lr",
+ type=float,
+ default=0.003,
+ help="The initial learning rate. This value should not need to be changed.",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate decreases.
+ We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss; it means how many symbols (context) "
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) "
+ "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version of "
+ "the loss (the joiner is just an addition); this simple loss is also "
+ "used for training (as a regularization term). We scale the simple "
+ "loss with this parameter before adding it to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=8000,
+ help="""Save checkpoint after processing this number of batches
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=20,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+ Explanation of options saved in `params`:
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+ - best_train_epoch: It is the epoch that has the best training loss.
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+ - subsampling_factor: The subsampling factor for the model.
+ - encoder_dim: Hidden dim for multi-head attention model.
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+ - warm_step: The warm_step for Noam optimizer.
+ """
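+ # `params` is an AttributeDict, so entries can be read either as
+ # params["encoder_dim"] or params.encoder_dim.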
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 10,
+ "log_interval": 1,
+ "reset_interval": 200,
+ "valid_interval": 400,
+ # parameters for conformer
+ "feature_dim": 80,
+ "subsampling_factor": 4,
+ "encoder_dim": 512,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "num_encoder_layers": 12,
+ # parameters for decoder
+ "decoder_dim": 512,
+ # parameters for joiner
+ "joiner_dim": 512,
+ # parameters for Noam
+ "model_warm_step": 200,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Conformer and Transformer
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.encoder_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`.
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
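+ # For example, --start-batch 4000 resumes from
+ # exp-dir/checkpoint-4000.pt, while --start-epoch 10 resumes from
+ # exp-dir/epoch-9.pt.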
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 0:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+ warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute CTC loss given the model and its inputs.
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ texts = batch["supervisions"]["text"]
+
+ y = graph_compiler.texts_to_ids(texts)
+ if type(y) == list:
+ y = k2.RaggedTensor(y).to(device)
+ else:
+ y = y.to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ warmup=warmup,
+ )
+ # after the main warmup step, we keep pruned_loss_scale small
+ # for the same amount of time (model_warm_step), to avoid
+ # overwhelming the simple_loss and causing it to diverge,
+ # in case it had not fully learned the alignment yet.
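+ # Concretely: warmup < 1.0 -> pruned_loss_scale = 0.0;
+ # 1.0 < warmup < 2.0 -> 0.1; otherwise -> 1.0.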
+ pruned_loss_scale = (
+ 0.0
+ if warmup < 1.0
+ else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+ )
+ loss = (
+ params.simple_loss_scale * simple_loss
+ + pruned_loss_scale * pruned_loss
+ )
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (
+ (feature_lens // params.subsampling_factor).sum().item()
+ )
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ warmup=(params.batch_idx_train / params.model_warm_step),
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
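+ # tot_loss is an exponentially decaying running sum: each step the old
+ # value is scaled by (1 - 1/reset_interval), so the logged tot_loss
+ # reflects roughly the last `reset_interval` batches.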
+
+ # NOTE: We use reduction=sum, so the loss is summed over utterances
+ # in the batch; no normalization has been applied to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}"
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank])
+ model.device = device
+
+ optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2 ** 22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ alimeeting = AlimeetingAsrDataModule(args)
+
+ train_cuts = alimeeting.train_cuts()
+ valid_cuts = alimeeting.valid_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 15 seconds.
+        #
+        # Caution: There is a reason to select 15.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # the utterance duration distribution of your dataset and select
+        # the thresholds accordingly.
+ return 1.0 <= c.duration <= 15.0
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ valid_dl = alimeeting.valid_dataloaders(valid_cuts)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We load the sampler's state dict only when resuming from a
+        # checkpoint saved in the middle of an epoch.
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = alimeeting.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ if not params.print_diagnostics and params.start_batch == 0:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs):
+ scheduler.step_epoch(epoch)
+ fix_random_seed(params.seed + epoch)
+ train_dl.sampler.set_epoch(epoch)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+ model: nn.Module,
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+            # warmup = 0.0 is used so that the derivatives for the pruned loss
+            # stay zero (i.e. are not remembered by the decaying average in
+            # Adam), because we want to avoid these params being subject to
+            # shrinkage in Adam.
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ warmup=0.0,
+ )
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ AlimeetingAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.lang_dir = Path(args.lang_dir)
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/ASR/shared b/egs/alimeeting/ASR/shared
new file mode 120000
index 000000000..3a3b28f96
--- /dev/null
+++ b/egs/alimeeting/ASR/shared
@@ -0,0 +1 @@
+../../../egs/aishell/ASR/shared
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/.gitignore b/egs/gigaspeech/ASR/.gitignore
new file mode 100644
index 000000000..5592679cc
--- /dev/null
+++ b/egs/gigaspeech/ASR/.gitignore
@@ -0,0 +1 @@
+log-*
diff --git a/egs/gigaspeech/ASR/README.md b/egs/gigaspeech/ASR/README.md
new file mode 100644
index 000000000..32a0457c6
--- /dev/null
+++ b/egs/gigaspeech/ASR/README.md
@@ -0,0 +1,21 @@
+# GigaSpeech
+GigaSpeech is an evolving, multi-domain English
+speech recognition corpus with 10,000 hours of high-quality labeled
+audio, collected from audiobooks, podcasts,
+and YouTube. It covers both read and spontaneous speaking styles
+and a variety of topics, such as arts, science, and sports. More details can be found at https://github.com/SpeechColab/GigaSpeech
+
+## Download
+
+Apply for the download credentials and download the dataset by following https://github.com/SpeechColab/GigaSpeech#download. Then create a symlink:
+```bash
+ln -sfv /path/to/GigaSpeech download/GigaSpeech
+```
+
+## Performance Record
+| | Dev | Test |
+|--------------------------------|-------|-------|
+| `conformer_ctc` | 10.47 | 10.58 |
+| `pruned_transducer_stateless2` | 10.40 | 10.51 |
+
+See [RESULTS](/egs/gigaspeech/ASR/RESULTS.md) for details.
diff --git a/egs/gigaspeech/ASR/RESULTS.md b/egs/gigaspeech/ASR/RESULTS.md
new file mode 100644
index 000000000..7ab565844
--- /dev/null
+++ b/egs/gigaspeech/ASR/RESULTS.md
@@ -0,0 +1,152 @@
+## Results
+### GigaSpeech BPE training results (Pruned Transducer 2)
+
+#### 2022-05-12
+
+#### Conformer encoder + embedding decoder
+
+Conformer encoder + non-recurrent decoder. The encoder is a
+reworked version of the conformer encoder, with many changes. The
+decoder contains only an embedding layer, a Conv1d (with kernel
+size 2) and a linear layer (to transform tensor dim). k2 pruned
+RNN-T loss is used.
+
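+To make the decoder structure concrete, here is a minimal, illustrative
+sketch of such a stateless decoder (not the exact icefall implementation;
+the class and parameter names below are made up for illustration):
+
+```python
+import torch
+import torch.nn as nn
+
+
+class StatelessDecoder(nn.Module):
+    """Embedding + causal Conv1d (kernel size 2) + Linear, as described above."""
+
+    def __init__(self, vocab_size: int, embed_dim: int, out_dim: int):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        # A kernel size of 2 means each output frame sees only the current
+        # and the previous token, i.e. there is no recurrent state.
+        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=2)
+        self.out = nn.Linear(embed_dim, out_dim)
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        # y: (N, U) token IDs; returns (N, U, out_dim)
+        x = self.embedding(y).permute(0, 2, 1)  # (N, embed_dim, U)
+        x = nn.functional.pad(x, (1, 0))        # left-pad for causality
+        x = self.conv(x).permute(0, 2, 1)       # (N, U, embed_dim)
+        return self.out(x)
+```
+
+In the actual recipe the decoder output is combined with the encoder output
+by a joiner network before the pruned RNN-T loss is computed.
+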
+The best WERs for GigaSpeech, as of 2022-05-12, are:
+
+| | Dev | Test |
+|----------------------|-------|-------|
+| greedy search | 10.51 | 10.73 |
+| fast beam search | 10.50 | 10.69 |
+| modified beam search | 10.40 | 10.51 |
+
+To reproduce the above result, use the following commands for training:
+
+```bash
+cd egs/gigaspeech/ASR
+./prepare.sh
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+./pruned_transducer_stateless2/train.py \
+ --max-duration 120 \
+ --num-workers 1 \
+ --world-size 8 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --use-fp16 True
+```
+
+and the following commands for decoding:
+
+```bash
+# greedy search
+./pruned_transducer_stateless2/decode.py \
+ --iter 3488000 \
+ --avg 20 \
+ --decoding-method greedy_search \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --max-duration 600
+
+# fast beam search
+./pruned_transducer_stateless2/decode.py \
+ --iter 3488000 \
+ --avg 20 \
+ --decoding-method fast_beam_search \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --max-duration 600
+
+# modified beam search
+./pruned_transducer_stateless2/decode.py \
+ --iter 3488000 \
+ --avg 15 \
+ --decoding-method modified_beam_search \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --max-duration 600
+```
+
+Pretrained model is available at
+
+
+The tensorboard log for training is available at
+
+
+### GigaSpeech BPE training results (Conformer-CTC)
+
+#### 2022-04-06
+
+The best WERs for GigaSpeech, as of 2022-04-06, are given below.
+
+Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring:
+
+| | Dev | Test |
+|-----|-------|-------|
+| WER | 10.47 | 10.58 |
+
+Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 0.5 | 1.3 |
+
+
+To reproduce the above result, use the following commands for training:
+
+```bash
+cd egs/gigaspeech/ASR
+./prepare.sh
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+./conformer_ctc/train.py \
+ --max-duration 120 \
+ --num-workers 1 \
+ --world-size 8 \
+ --exp-dir conformer_ctc/exp_500 \
+ --lang-dir data/lang_bpe_500
+```
+
+and the following command for decoding:
+
+```bash
+./conformer_ctc/decode.py \
+ --epoch 18 \
+ --avg 6 \
+ --method attention-decoder \
+ --num-paths 1000 \
+ --exp-dir conformer_ctc/exp_500 \
+ --lang-dir data/lang_bpe_500 \
+ --max-duration 20 \
+ --num-workers 1
+```
+
+Results using HLG decoding + whole lattice rescoring:
+
+| | Dev | Test |
+|-----|-------|-------|
+| WER | 10.51 | 10.62 |
+
+The LM scale value used in whole lattice rescoring for the best WERs is:
+| lm_scale |
+|----------|
+| 0.2 |
+
+To reproduce the above result, use the training commands above, and the following command for decoding:
+
+```bash
+./conformer_ctc/decode.py \
+ --epoch 18 \
+ --avg 6 \
+ --method whole-lattice-rescoring \
+ --num-paths 1000 \
+ --exp-dir conformer_ctc/exp_500 \
+ --lang-dir data/lang_bpe_500 \
+ --max-duration 20 \
+ --num-workers 1
+```
+Note: the `whole-lattice-rescoring` method is about twice as fast as the `attention-decoder` method, with slightly worse WER.
+
+Pretrained model is available at
+
+
+The tensorboard log for training is available at
+
diff --git a/egs/gigaspeech/ASR/conformer_ctc/__init__.py b/egs/gigaspeech/ASR/conformer_ctc/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
new file mode 100644
index 000000000..d78e26240
--- /dev/null
+++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
@@ -0,0 +1,376 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SingleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class GigaSpeechAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
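+    # A minimal usage sketch (assuming `args` is an argparse.Namespace that
+    # has been populated after calling GigaSpeechAsrDataModule.add_arguments
+    # on the parser):
+    #
+    #   gigaspeech = GigaSpeechAsrDataModule(args)
+    #   train_dl = gigaspeech.train_dataloaders(gigaspeech.train_cuts())
+    #   valid_dl = gigaspeech.valid_dataloaders(gigaspeech.dev_cuts())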
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+            help="The number of buckets for the DynamicBucketingSampler "
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it "
+ "with training dataset. ",
+ )
+
+ # GigaSpeech specific arguments
+ group.add_argument(
+ "--subset",
+ type=str,
+ default="XL",
+ help="Select the GigaSpeech subset (XS|S|M|L|XL)",
+ )
+ group.add_argument(
+ "--small-dev",
+ type=str2bool,
+ default=False,
+ help="Should we use only 1000 utterances for dev "
+ "(speeds up training)",
+ )
+
+ def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(
+ self.args.manifest_dir / "musan_cuts.jsonl.gz"
+ )
+
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ transforms.append(
+ CutMix(
+ cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+ )
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(
+ f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+ )
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=2,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=True,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SingleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info(f"About to get train_{self.args.subset} cuts")
+ path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
+ cuts_train = CutSet.from_jsonl_lazy(path)
+ return cuts_train
+
+ @lru_cache()
+ def dev_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ cuts_valid = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_DEV.jsonl.gz"
+ )
+ if self.args.small_dev:
+ return cuts_valid.subset(first=1000)
+ else:
+ return cuts_valid
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py
new file mode 100644
index 000000000..36e0c7aea
--- /dev/null
+++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py
@@ -0,0 +1,930 @@
+#!/usr/bin/env python3
+# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+from transformer import Supervisions, Transformer, encoder_padding_mask
+
+
+class Conformer(Transformer):
+ """
+ Args:
+ num_features (int): Number of input features
+ num_classes (int): Number of output classes
+ subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
+ d_model (int): attention dimension
+        nhead (int): number of attention heads
+        dim_feedforward (int): feedforward dimension
+ num_encoder_layers (int): number of encoder layers
+ num_decoder_layers (int): number of decoder layers
+ dropout (float): dropout rate
+ cnn_module_kernel (int): Kernel size of convolution module
+ normalize_before (bool): whether to use layer_norm before the first block.
+ vgg_frontend (bool): whether to use vgg frontend.
+ """
+
+ def __init__(
+ self,
+ num_features: int,
+ num_classes: int,
+ subsampling_factor: int = 4,
+ d_model: int = 256,
+ nhead: int = 4,
+ dim_feedforward: int = 2048,
+ num_encoder_layers: int = 12,
+ num_decoder_layers: int = 6,
+ dropout: float = 0.1,
+ cnn_module_kernel: int = 31,
+ normalize_before: bool = True,
+ vgg_frontend: bool = False,
+ use_feat_batchnorm: Union[float, bool] = 0.1,
+ ) -> None:
+ super(Conformer, self).__init__(
+ num_features=num_features,
+ num_classes=num_classes,
+ subsampling_factor=subsampling_factor,
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ num_encoder_layers=num_encoder_layers,
+ num_decoder_layers=num_decoder_layers,
+ dropout=dropout,
+ normalize_before=normalize_before,
+ vgg_frontend=vgg_frontend,
+ use_feat_batchnorm=use_feat_batchnorm,
+ )
+
+ self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+ use_conv_batchnorm = True
+ if isinstance(use_feat_batchnorm, float):
+ use_conv_batchnorm = False
+ encoder_layer = ConformerEncoderLayer(
+ d_model,
+ nhead,
+ dim_feedforward,
+ dropout,
+ cnn_module_kernel,
+ normalize_before,
+ use_conv_batchnorm,
+ )
+ self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+ self.normalize_before = normalize_before
+ if self.normalize_before:
+ self.after_norm = nn.LayerNorm(d_model)
+ else:
+ # Note: TorchScript detects that self.after_norm could be used inside forward()
+ # and throws an error without this change.
+ self.after_norm = identity
+
+ def run_encoder(
+ self, x: Tensor, supervisions: Optional[Supervisions] = None
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """
+ Args:
+ x:
+ The model input. Its shape is (N, T, C).
+ supervisions:
+ Supervision in lhotse format.
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
+ CAUTION: It contains length information, i.e., start and number of
+ frames, before subsampling
+ It is read directly from the batch, without any sorting. It is used
+ to compute encoder padding mask, which is used as memory key padding
+ mask for the decoder.
+
+ Returns:
+ Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
+ Tensor: Mask tensor of dimension (batch_size, input_length)
+ """
+ x = self.encoder_embed(x)
+ x, pos_emb = self.encoder_pos(x)
+ x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
+ mask = encoder_padding_mask(x.size(0), supervisions)
+ if mask is not None:
+ mask = mask.to(x.device)
+ x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
+
+ if self.normalize_before:
+ x = self.after_norm(x)
+
+ return x, mask
+
+
+class ConformerEncoderLayer(nn.Module):
+ """
+ ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
+ See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
+
+ Args:
+ d_model: the number of expected features in the input (required).
+ nhead: the number of heads in the multiheadattention models (required).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ cnn_module_kernel (int): Kernel size of convolution module.
+ normalize_before: whether to use layer_norm before the first block.
+
+ Examples::
+ >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+ >>> src = torch.rand(10, 32, 512)
+ >>> pos_emb = torch.rand(32, 19, 512)
+ >>> out = encoder_layer(src, pos_emb)
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ cnn_module_kernel: int = 31,
+ normalize_before: bool = True,
+ use_conv_batchnorm: bool = False,
+ ) -> None:
+ super(ConformerEncoderLayer, self).__init__()
+ self.self_attn = RelPositionMultiheadAttention(
+ d_model, nhead, dropout=0.0
+ )
+
+ self.feed_forward = nn.Sequential(
+ nn.Linear(d_model, dim_feedforward),
+ Swish(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_feedforward, d_model),
+ )
+
+ self.feed_forward_macaron = nn.Sequential(
+ nn.Linear(d_model, dim_feedforward),
+ Swish(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_feedforward, d_model),
+ )
+
+ self.conv_module = ConvolutionModule(
+ d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm
+ )
+
+ self.norm_ff_macaron = nn.LayerNorm(
+ d_model
+ ) # for the macaron style FNN module
+ self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
+ self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
+
+ self.ff_scale = 0.5
+
+ self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
+ self.norm_final = nn.LayerNorm(
+ d_model
+ ) # for the final output of the block
+
+ self.dropout = nn.Dropout(dropout)
+
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Pass the input through the encoder layer.
+
+ Args:
+ src: the sequence to the encoder layer (required).
+ pos_emb: Positional embedding tensor (required).
+ src_mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+
+ Shape:
+ src: (S, N, E).
+ pos_emb: (N, 2*S-1, E)
+ src_mask: (S, S).
+ src_key_padding_mask: (N, S).
+ S is the source sequence length, N is the batch size, E is the feature number
+ """
+
+ # macaron style feed forward module
+ residual = src
+ if self.normalize_before:
+ src = self.norm_ff_macaron(src)
+ src = residual + self.ff_scale * self.dropout(
+ self.feed_forward_macaron(src)
+ )
+ if not self.normalize_before:
+ src = self.norm_ff_macaron(src)
+
+ # multi-headed self-attention module
+ residual = src
+ if self.normalize_before:
+ src = self.norm_mha(src)
+ src_att = self.self_attn(
+ src,
+ src,
+ src,
+ pos_emb=pos_emb,
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask,
+ )[0]
+ src = residual + self.dropout(src_att)
+ if not self.normalize_before:
+ src = self.norm_mha(src)
+
+ # convolution module
+ residual = src
+ if self.normalize_before:
+ src = self.norm_conv(src)
+ src = residual + self.dropout(self.conv_module(src))
+ if not self.normalize_before:
+ src = self.norm_conv(src)
+
+ # feed forward module
+ residual = src
+ if self.normalize_before:
+ src = self.norm_ff(src)
+ src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
+ if not self.normalize_before:
+ src = self.norm_ff(src)
+
+ if self.normalize_before:
+ src = self.norm_final(src)
+
+ return src
+
+
+class ConformerEncoder(nn.TransformerEncoder):
+ r"""ConformerEncoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the ConformerEncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+ norm: the layer normalization component (optional).
+
+ Examples::
+ >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
+ >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
+ >>> src = torch.rand(10, 32, 512)
+ >>> pos_emb = torch.rand(32, 19, 512)
+ >>> out = conformer_encoder(src, pos_emb)
+ """
+
+ def __init__(
+ self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
+ ) -> None:
+ super(ConformerEncoder, self).__init__(
+ encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
+ )
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required).
+ pos_emb: Positional embedding tensor (required).
+ mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+
+ Shape:
+ src: (S, N, E).
+ pos_emb: (N, 2*S-1, E)
+ mask: (S, S).
+ src_key_padding_mask: (N, S).
+ S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
+
+ """
+ output = src
+
+ for mod in self.layers:
+ output = mod(
+ output,
+ pos_emb,
+ src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class RelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding module.
+
+ See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+ Args:
+ d_model: Embedding dimension.
+ dropout_rate: Dropout rate.
+ max_len: Maximum input length.
+
+ """
+
+ def __init__(
+ self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
+        """Construct a RelPositionalEncoding object."""
+ super(RelPositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x: Tensor) -> None:
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ # Note: TorchScript doesn't implement operator== for torch.Device
+ if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+ x.device
+ ):
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+        # Suppose `i` means the position of the query vector and `j` means
+        # the position of the key vector. We use positive relative positions
+        # when keys are to the left (i>j) and negative relative positions
+        # otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # described in "Transformer-XL: Attentive Language Models Beyond a
+        # Fixed-Length Context".
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2
+ - x.size(1)
+ + 1 : self.pe.size(1) // 2 # noqa E203
+ + x.size(1),
+ ]
+ return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+ r"""Multi-Head Attention layer with relative position encoding
+
+ See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+ Args:
+ embed_dim: total dimension of the model.
+ num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+
+ Examples::
+
+ >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ ) -> None:
+ super(RelPositionMultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
+
+ self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+ # linear transformation for positional encoding.
+ self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ nn.init.xavier_uniform_(self.in_proj.weight)
+ nn.init.constant_(self.in_proj.bias, 0.0)
+ nn.init.constant_(self.out_proj.bias, 0.0)
+
+ nn.init.xavier_uniform_(self.pos_bias_u)
+ nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ pos_emb: Positional embedding tensor
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. When given a binary mask and a value is True,
+ the corresponding value on the attention layer will be ignored. When given
+ a byte mask and a value is non-zero, the corresponding value on the attention
+ layer will be ignored
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+
+ Shape:
+ - Inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+
+ - Outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+ return self.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ pos_emb,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj.weight,
+ self.in_proj.bias,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ )
+
+ def rel_shift(self, x: Tensor) -> Tensor:
+ """Compute relative positional encoding.
+
+ Args:
+ x: Input tensor (batch, head, time1, 2*time1-1).
+ time1 means the length of query vector.
+
+ Returns:
+ Tensor: tensor of shape (batch, head, time1, time2)
+ (note: time2 has the same value as time1, but it is for
+ the key, while time1 is for the query).
+ """
+ (batch_size, num_heads, time1, n) = x.shape
+ assert n == 2 * time1 - 1
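+        # The strided view below realizes
+        #   out[b, h, i, j] = x[b, h, i, j - i + time1 - 1],
+        # i.e. each query row i is shifted so that the last axis is indexed
+        # by the absolute key position j rather than the relative position.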
+ # Note: TorchScript requires explicit arg for stride()
+ batch_stride = x.stride(0)
+ head_stride = x.stride(1)
+ time1_stride = x.stride(2)
+ n_stride = x.stride(3)
+ return x.as_strided(
+ (batch_size, num_heads, time1, time1),
+ (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+ storage_offset=n_stride * (time1 - 1),
+ )
+
+ def multi_head_attention_forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ pos_emb: Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Tensor,
+ in_proj_bias: Tensor,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Tensor,
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ pos_emb: Positional embedding tensor
+ embed_dim_to_check: total dimension of the model.
+ num_heads: parallel attention heads.
+ in_proj_weight, in_proj_bias: input projection weight and bias.
+ dropout_p: probability of an element to be zeroed.
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
+ training: apply dropout if is ``True``.
+ key_padding_mask: if provided, specified padding elements in the key will
+          be ignored by the attention. This is a binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+
+ Shape:
+ Inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
+ length, N is the batch size, E is the embedding dimension.
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+ will be unchanged. If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+
+ Outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == embed_dim_to_check
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+ head_dim = embed_dim // num_heads
+ assert (
+ head_dim * num_heads == embed_dim
+ ), "embed_dim must be divisible by num_heads"
+ scaling = float(head_dim) ** -0.5
+
+ if torch.equal(query, key) and torch.equal(key, value):
+ # self-attention
+ q, k, v = nn.functional.linear(
+ query, in_proj_weight, in_proj_bias
+ ).chunk(3, dim=-1)
+
+ elif torch.equal(key, value):
+ # encoder-decoder attention
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = nn.functional.linear(query, _w, _b)
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+
+ else:
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = nn.functional.linear(query, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = embed_dim * 2
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ k = nn.functional.linear(key, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim * 2
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ v = nn.functional.linear(value, _w, _b)
+
+ if attn_mask is not None:
+ assert (
+ attn_mask.dtype == torch.float32
+ or attn_mask.dtype == torch.float64
+ or attn_mask.dtype == torch.float16
+ or attn_mask.dtype == torch.uint8
+ or attn_mask.dtype == torch.bool
+ ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
+ attn_mask.dtype
+ )
+ if attn_mask.dtype == torch.uint8:
+ warnings.warn(
+ "Byte tensor for attn_mask is deprecated. Use bool tensor instead."
+ )
+ attn_mask = attn_mask.to(torch.bool)
+
+ if attn_mask.dim() == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+ raise RuntimeError(
+ "The size of the 2D attn_mask is not correct."
+ )
+ elif attn_mask.dim() == 3:
+ if list(attn_mask.size()) != [
+ bsz * num_heads,
+ query.size(0),
+ key.size(0),
+ ]:
+ raise RuntimeError(
+ "The size of the 3D attn_mask is not correct."
+ )
+ else:
+ raise RuntimeError(
+ "attn_mask's dimension {} is not supported".format(
+ attn_mask.dim()
+ )
+ )
+ # attn_mask's dim is 3 now.
+
+ # convert ByteTensor key_padding_mask to bool
+ if (
+ key_padding_mask is not None
+ and key_padding_mask.dtype == torch.uint8
+ ):
+ warnings.warn(
+ "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
+ )
+ key_padding_mask = key_padding_mask.to(torch.bool)
+
+ q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
+ k = k.contiguous().view(-1, bsz, num_heads, head_dim)
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+ src_len = k.size(0)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz, "{} == {}".format(
+ key_padding_mask.size(0), bsz
+ )
+ assert key_padding_mask.size(1) == src_len, "{} == {}".format(
+ key_padding_mask.size(1), src_len
+ )
+
+ q = q.transpose(0, 1) # (batch, time1, head, d_k)
+
+ pos_emb_bsz = pos_emb.size(0)
+ assert pos_emb_bsz in (1, bsz) # actually it is 1
+ p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
+ p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
+
+ q_with_bias_u = (q + self.pos_bias_u).transpose(
+ 1, 2
+ ) # (batch, head, time1, d_k)
+
+ q_with_bias_v = (q + self.pos_bias_v).transpose(
+ 1, 2
+ ) # (batch, head, time1, d_k)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+ k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
+ matrix_ac = torch.matmul(
+ q_with_bias_u, k
+ ) # (batch, head, time1, time2)
+
+ # compute matrix b and matrix d
+ matrix_bd = torch.matmul(
+ q_with_bias_v, p.transpose(-2, -1)
+ ) # (batch, head, time1, 2*time1-1)
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ attn_output_weights = (
+ matrix_ac + matrix_bd
+ ) * scaling # (batch, head, time1, time2)
+
+ attn_output_weights = attn_output_weights.view(
+ bsz * num_heads, tgt_len, -1
+ )
+
+ assert list(attn_output_weights.size()) == [
+ bsz * num_heads,
+ tgt_len,
+ src_len,
+ ]
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_output_weights.masked_fill_(attn_mask, float("-inf"))
+ else:
+ attn_output_weights += attn_mask
+
+ if key_padding_mask is not None:
+ attn_output_weights = attn_output_weights.view(
+ bsz, num_heads, tgt_len, src_len
+ )
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float("-inf"),
+ )
+ attn_output_weights = attn_output_weights.view(
+ bsz * num_heads, tgt_len, src_len
+ )
+
+ attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+ attn_output_weights = nn.functional.dropout(
+ attn_output_weights, p=dropout_p, training=training
+ )
+
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+ attn_output = (
+ attn_output.transpose(0, 1)
+ .contiguous()
+ .view(tgt_len, bsz, embed_dim)
+ )
+ attn_output = nn.functional.linear(
+ attn_output, out_proj_weight, out_proj_bias
+ )
+
+ if need_weights:
+ # average attention weights over heads
+ attn_output_weights = attn_output_weights.view(
+ bsz, num_heads, tgt_len, src_len
+ )
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
+ else:
+ return attn_output, None
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model.
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+ Args:
+ channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+ bias (bool): Whether to use bias in conv layers (default=True).
+
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ bias: bool = True,
+ use_batchnorm: bool = False,
+    ) -> None:
+        """Construct a ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+ self.use_batchnorm = use_batchnorm
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ groups=channels,
+ bias=bias,
+ )
+ if self.use_batchnorm:
+ self.norm = nn.BatchNorm1d(channels)
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = Swish()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Compute convolution module.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+
+ Returns:
+ Tensor: Output tensor (#time, batch, channels).
+
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channels, time)
+ x = nn.functional.glu(x, dim=1) # (batch, channels, time)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ if self.use_batchnorm:
+ x = self.norm(x)
+ x = self.activation(x)
+
+ x = self.pointwise_conv2(x) # (batch, channel, time)
+
+ return x.permute(2, 0, 1)
+
+
+class Swish(torch.nn.Module):
+    """Construct a Swish object."""
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Return the Swish activation function."""
+ return x * torch.sigmoid(x)
+
+
+def identity(x):
+ return x
diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py
new file mode 100755
index 000000000..6ab9852b4
--- /dev/null
+++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py
@@ -0,0 +1,715 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
+# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from conformer import Conformer
+from gigaspeech_scoring import asr_text_post_processing
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.decode import (
+ get_lattice,
+ nbest_decoding,
+ nbest_oracle,
+ one_best_decoding,
+ rescore_with_attention_decoder,
+ rescore_with_n_best_list,
+ rescore_with_whole_lattice,
+)
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ get_texts,
+ setup_logger,
+ store_transcripts,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=0,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=1,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="attention-decoder",
+ help="""Decoding method.
+ Supported values are:
+ - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+ model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+ It needs neither a lexicon nor an n-gram LM.
+ - (1) 1best. Extract the best path from the decoding lattice as the
+ decoding result.
+ - (2) nbest. Extract n paths from the decoding lattice; the path
+ with the highest score is the decoding result.
+ - (3) nbest-rescoring. Extract n paths from the decoding lattice,
+ rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+ the highest score is the decoding result.
+ - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
+ n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
+ is the decoding result.
+ - (5) attention-decoder. Extract n paths from the LM rescored
+ lattice, the path with the highest score is the decoding result.
+          - (6) nbest-oracle. Its WER is a lower bound on what any n-best
+            rescoring method can achieve. Useful for debugging n-best
+            rescoring methods.
+ """,
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=1000,
+ help="""Number of paths for n-best based decoding method.
+ Used only when "method" is one of the following values:
+ nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kind of n-best based rescoring.
+ Used only when "method" is one of the following values:
+ nbest, nbest-rescoring, attention-decoder, and nbest-oracle
+ A smaller value results in more unique paths.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="conformer_ctc/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_bpe_500",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--lm-dir",
+ type=str,
+ default="data/lm",
+ help="""The LM dir.
+ It should contain either G_4_gram.pt or G_4_gram.fst.txt
+ """,
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ # parameters for conformer
+ "subsampling_factor": 4,
+ "vgg_frontend": False,
+ "use_feat_batchnorm": True,
+ "feature_dim": 80,
+ "nhead": 8,
+ "attention_dim": 512,
+ "num_decoder_layers": 6,
+ # parameters for decoding
+ "search_beam": 20,
+ "output_beam": 8,
+ "min_active_states": 30,
+ "max_active_states": 10000,
+ "use_double_scores": True,
+ "env_info": get_env_info(),
+ }
+ )
+ return params
+
+
+def post_processing(
+ results: List[Tuple[List[str], List[str]]],
+) -> List[Tuple[List[str], List[str]]]:
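+    # Normalize both the reference and the hypothesis with GigaSpeech's
+    # official text post-processing before scoring.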
+ new_results = []
+ for ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ HLG: Optional[k2.Fsa],
+ H: Optional[k2.Fsa],
+ bpe_model: Optional[spm.SentencePieceProcessor],
+ batch: dict,
+ word_table: k2.SymbolTable,
+ sos_id: int,
+ eos_id: int,
+ G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if no rescoring is used, the key is the string `no_rescore`.
+ If LM rescoring is used, the key is the string `lm_scale_xxx`,
+ where `xxx` is the value of `lm_scale`. An example key is
+ `lm_scale_0.7`
+      - value: It contains the decoding result. `len(value)` equals the
+        batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+
+        - If params.method is "1best", it uses 1best decoding without LM rescoring.
+        - If params.method is "nbest", it uses nbest decoding without LM rescoring.
+        - If params.method is "nbest-rescoring", it uses nbest LM rescoring.
+        - If params.method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+ model:
+ The neural model.
+ HLG:
+ The decoding graph. Used only when params.method is NOT ctc-decoding.
+ H:
+ The ctc topo. Used only when params.method is ctc-decoding.
+ bpe_model:
+ The BPE model. Used only when params.method is ctc-decoding.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ sos_id:
+ The token ID of the SOS.
+ eos_id:
+ The token ID of the EOS.
+ G:
+ An LM. It is not None when params.method is "nbest-rescoring"
+ or "whole-lattice-rescoring". In general, the G in HLG
+ is a 3-gram LM, while this G is a 4-gram LM.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict. Note: If it decodes to nothing, then return None.
+ """
+ if HLG is not None:
+ device = HLG.device
+ else:
+ device = H.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+
+ nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+ # nnet_output is (N, T, C)
+
+ supervision_segments = torch.stack(
+ (
+ supervisions["sequence_idx"],
+ supervisions["start_frame"] // params.subsampling_factor,
+ supervisions["num_frames"] // params.subsampling_factor,
+ ),
+ 1,
+ ).to(torch.int32)
+
+ if H is None:
+ assert HLG is not None
+ decoding_graph = HLG
+ else:
+ assert HLG is None
+ assert bpe_model is not None
+ decoding_graph = H
+
+ lattice = get_lattice(
+ nnet_output=nnet_output,
+ decoding_graph=decoding_graph,
+ supervision_segments=supervision_segments,
+ search_beam=params.search_beam,
+ output_beam=params.output_beam,
+ min_active_states=params.min_active_states,
+ max_active_states=params.max_active_states,
+ subsampling_factor=params.subsampling_factor,
+ )
+
+ if params.method == "ctc-decoding":
+ best_path = one_best_decoding(
+ lattice=lattice, use_double_scores=params.use_double_scores
+ )
+ # Note: `best_path.aux_labels` contains token IDs, not word IDs
+ # since we are using H, not HLG here.
+ #
+ # token_ids is a list-of-lists of IDs
+ token_ids = get_texts(best_path)
+
+ # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+ hyps = bpe_model.decode(token_ids)
+
+ # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+ hyps = [s.split() for s in hyps]
+ key = "ctc-decoding"
+ return {key: hyps}
+
+ if params.method == "nbest-oracle":
+ # Note: You can also pass rescored lattices to it.
+ # We choose the HLG decoded lattice for speed reasons
+ # as HLG decoding is faster and the oracle WER
+ # is only slightly worse than that of rescored lattices.
+ best_path = nbest_oracle(
+ lattice=lattice,
+ num_paths=params.num_paths,
+ ref_texts=supervisions["text"],
+ word_table=word_table,
+ nbest_scale=params.nbest_scale,
+ oov="",
+ )
+ hyps = get_texts(best_path)
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
+ key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
+ return {key: hyps}
+
+ if params.method in ["1best", "nbest"]:
+ if params.method == "1best":
+ best_path = one_best_decoding(
+ lattice=lattice, use_double_scores=params.use_double_scores
+ )
+ key = "no_rescore"
+ else:
+ best_path = nbest_decoding(
+ lattice=lattice,
+ num_paths=params.num_paths,
+ use_double_scores=params.use_double_scores,
+ nbest_scale=params.nbest_scale,
+ )
+ key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
+
+ hyps = get_texts(best_path)
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
+ return {key: hyps}
+
+ assert params.method in [
+ "nbest-rescoring",
+ "whole-lattice-rescoring",
+ "attention-decoder",
+ ]
+
+ lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+ lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+ lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+ if params.method == "nbest-rescoring":
+ best_path_dict = rescore_with_n_best_list(
+ lattice=lattice,
+ G=G,
+ num_paths=params.num_paths,
+ lm_scale_list=lm_scale_list,
+ nbest_scale=params.nbest_scale,
+ )
+ elif params.method == "whole-lattice-rescoring":
+ best_path_dict = rescore_with_whole_lattice(
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.method == "attention-decoder":
+ # The lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+ rescored_lattice = rescore_with_whole_lattice(
+ lattice=lattice,
+ G_with_epsilon_loops=G,
+ lm_scale_list=None,
+ )
+ # TODO: pass `lattice` instead of `rescored_lattice` to
+ # `rescore_with_attention_decoder`
+
+ best_path_dict = rescore_with_attention_decoder(
+ lattice=rescored_lattice,
+ num_paths=params.num_paths,
+ model=model,
+ memory=memory,
+ memory_key_padding_mask=memory_key_padding_mask,
+ sos_id=sos_id,
+ eos_id=eos_id,
+ nbest_scale=params.nbest_scale,
+ )
+ else:
+ assert False, f"Unsupported decoding method: {params.method}"
+
+ ans = dict()
+ if best_path_dict is not None:
+ for lm_scale_str, best_path in best_path_dict.items():
+ hyps = get_texts(best_path)
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
+ ans[lm_scale_str] = hyps
+ else:
+ ans = None
+ return ans
+
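+# A sketch of the dict returned by decode_one_batch() (hypothetical values):
+#
+#   {"lm_scale_0.7": [["HELLO", "WORLD"], ["GOOD", "MORNING"]],
+#    "lm_scale_0.8": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}
+#
+# i.e., one entry per decoding setting and, for each setting, one hypothesis
+# (a list of words) per utterance in the batch.
+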
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ HLG: Optional[k2.Fsa],
+ H: Optional[k2.Fsa],
+ bpe_model: Optional[spm.SentencePieceProcessor],
+ word_table: k2.SymbolTable,
+ sos_id: int,
+ eos_id: int,
+ G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ HLG:
+ The decoding graph. Used only when params.method is NOT ctc-decoding.
+ H:
+ The ctc topo. Used only when params.method is ctc-decoding.
+ bpe_model:
+ The BPE model. Used only when params.method is ctc-decoding.
+ word_table:
+ It is the word symbol table.
+ sos_id:
+ The token ID for SOS.
+ eos_id:
+ The token ID for EOS.
+ G:
+ An LM. It is not None when params.method is "nbest-rescoring"
+ or "whole-lattice-rescoring". In general, the G in HLG
+ is a 3-gram LM, while this G is a 4-gram LM.
+ Returns:
+ Return a dict, whose key may be "no_rescore" if no LM rescoring
+ is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ HLG=HLG,
+ H=H,
+ bpe_model=bpe_model,
+ batch=batch,
+ word_table=word_table,
+ G=G,
+ sos_id=sos_id,
+ eos_id=eos_id,
+ )
+
+ if hyps_dict is not None:
+ for lm_scale, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((ref_words, hyp_words))
+
+ results[lm_scale].extend(this_batch)
+ else:
+ assert (
+ len(results) > 0
+ ), "It should not decode to empty in the first batch!"
+ this_batch = []
+ hyp_words = []
+ for ref_text in texts:
+ ref_words = ref_text.split()
+ this_batch.append((ref_words, hyp_words))
+
+ for lm_scale in results.keys():
+ results[lm_scale].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+ if params.method == "attention-decoder":
+ # Set it to False since there are too many logs.
+ enable_log = False
+ else:
+ enable_log = True
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+ results = post_processing(results)
+ store_transcripts(filename=recog_path, texts=results)
+ if enable_log:
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=enable_log
+ )
+ test_set_wers[key] = wer
+
+ if enable_log:
+ logging.info(
+ "Wrote detailed error stats to {}".format(errs_filename)
+ )
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
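+# The wer-summary-<test_set_name>.txt file written by save_results() above is
+# tab-separated with one row per decoding setting, sorted by WER, e.g.
+# (the WER numbers are illustrative only):
+#
+#   settings        WER
+#   lm_scale_0.7    11.2
+#   lm_scale_0.8    11.5
+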
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+ args.lm_dir = Path(args.lm_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+ logging.info("Decoding started")
+ logging.info(params)
+
+ lexicon = Lexicon(params.lang_dir)
+ max_token_id = max(lexicon.tokens)
+ num_classes = max_token_id + 1 # +1 for the blank
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ graph_compiler = BpeCtcTrainingGraphCompiler(
+ params.lang_dir,
+ device=device,
+ sos_token="",
+ eos_token="",
+ )
+ sos_id = graph_compiler.sos_id
+ eos_id = graph_compiler.eos_id
+
+ if params.method == "ctc-decoding":
+ HLG = None
+ H = k2.ctc_topo(
+ max_token=max_token_id,
+ modified=False,
+ device=device,
+ )
+ bpe_model = spm.SentencePieceProcessor()
+ bpe_model.load(str(params.lang_dir / "bpe.model"))
+ else:
+ H = None
+ bpe_model = None
+ HLG = k2.Fsa.from_dict(
+ torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+ )
+ assert HLG.requires_grad is False
+
+ if not hasattr(HLG, "lm_scores"):
+ HLG.lm_scores = HLG.scores.clone()
+
+ if params.method in (
+ "nbest-rescoring",
+ "whole-lattice-rescoring",
+ "attention-decoder",
+ ):
+ if not (params.lm_dir / "G_4_gram.pt").is_file():
+ logging.info("Loading G_4_gram.fst.txt")
+ logging.warning("It may take 8 minutes.")
+ with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+ first_word_disambig_id = lexicon.word_table["#0"]
+
+ G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+ # G.aux_labels is not needed in later computations, so
+ # remove it here.
+ del G.aux_labels
+ # CAUTION: The following line is crucial.
+ # Arcs entering the back-off state have label equal to #0.
+ # We have to change it to 0 here.
+ G.labels[G.labels >= first_word_disambig_id] = 0
+ # See https://github.com/k2-fsa/k2/issues/874
+ # for why we need to set G.properties to None
+ G.__dict__["_properties"] = None
+ G = k2.Fsa.from_fsas([G]).to(device)
+ G = k2.arc_sort(G)
+ # Save a dummy value so that it can be loaded in C++.
+ # See https://github.com/pytorch/pytorch/issues/67902
+ # for why we need to do this.
+ G.dummy = 1
+
+ torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+ else:
+ logging.info("Loading pre-compiled G_4_gram.pt")
+ d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+ G = k2.Fsa.from_dict(d)
+
+ if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
+ # Add epsilon self-loops to G as we will compose
+ # it with the whole lattice later
+ G = k2.add_epsilon_self_loops(G)
+ G = k2.arc_sort(G)
+ G = G.to(device)
+
+ # G.lm_scores is used to replace HLG.lm_scores during
+ # LM rescoring.
+ G.lm_scores = G.scores.clone()
+ else:
+ G = None
+
+ model = Conformer(
+ num_features=params.feature_dim,
+ nhead=params.nhead,
+ d_model=params.attention_dim,
+ num_classes=num_classes,
+ subsampling_factor=params.subsampling_factor,
+ num_decoder_layers=params.num_decoder_layers,
+ vgg_frontend=params.vgg_frontend,
+ use_feat_batchnorm=params.use_feat_batchnorm,
+ )
+
+ if params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
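+ # Average the model parameters of the last `params.avg` checkpoints,
+ # i.e., of epochs [params.epoch - params.avg + 1, ..., params.epoch].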
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+
+ model.to(device)
+ model.eval()
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ gigaspeech = GigaSpeechAsrDataModule(args)
+
+ dev_cuts = gigaspeech.dev_cuts()
+ test_cuts = gigaspeech.test_cuts()
+
+ dev_dl = gigaspeech.test_dataloaders(dev_cuts)
+ test_dl = gigaspeech.test_dataloaders(test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dls = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ HLG=HLG,
+ H=H,
+ bpe_model=bpe_model,
+ word_table=lexicon.word_table,
+ G=G,
+ sos_id=sos_id,
+ eos_id=eos_id,
+ )
+
+ save_results(
+ params=params, test_set_name=test_set, results_dict=results_dict
+ )
+
+ logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py
new file mode 100755
index 000000000..ef53b77f8
--- /dev/null
+++ b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python3
+# Copyright 2021 Jiayu Du
+# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import os
+
+conversational_filler = [
+ "UH",
+ "UHH",
+ "UM",
+ "EH",
+ "MM",
+ "HM",
+ "AH",
+ "HUH",
+ "HA",
+ "ER",
+ "OOF",
+ "HEE",
+ "ACH",
+ "EEE",
+ "EW",
+]
+unk_tags = ["<UNK>", "<unk>"]
+gigaspeech_punctuations = [
+ "<COMMA>",
+ "<PERIOD>",
+ "<QUESTIONMARK>",
+ "<EXCLAMATIONPOINT>",
+]
+gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
+non_scoring_words = (
+ conversational_filler
+ + unk_tags
+ + gigaspeech_punctuations
+ + gigaspeech_garbage_utterance_tags
+)
+
+
+def asr_text_post_processing(text: str) -> str:
+ # 1. convert to uppercase
+ text = text.upper()
+
+ # 2. remove hyphen
+ # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
+ text = text.replace("-", " ")
+
+ # 3. remove non-scoring words from evaluation
+ remaining_words = []
+ for word in text.split():
+ if word in non_scoring_words:
+ continue
+ remaining_words.append(word)
+
+ return " ".join(remaining_words)
+
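+# Illustrative example of asr_text_post_processing() (the input is made up):
+#   asr_text_post_processing("um this is state-of-the-art <COMMA>")
+#   -> "THIS IS STATE OF THE ART"
+# The text is upper-cased, hyphens become spaces, and "UM" and "<COMMA>" are
+# dropped because they appear in non_scoring_words.
+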
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="This script evaluates GigaSpeech ASR result via"
+ "SCTK's tool sclite"
+ )
+ parser.add_argument(
+ "ref",
+ type=str,
+ help="sclite's standard transcription(trn) reference file",
+ )
+ parser.add_argument(
+ "hyp",
+ type=str,
+ help="sclite's standard transcription(trn) hypothesis file",
+ )
+ parser.add_argument(
+ "work_dir",
+ type=str,
+ help="working dir",
+ )
+ args = parser.parse_args()
+
+ if not os.path.isdir(args.work_dir):
+ os.mkdir(args.work_dir)
+
+ REF = os.path.join(args.work_dir, "REF")
+ HYP = os.path.join(args.work_dir, "HYP")
+ RESULT = os.path.join(args.work_dir, "RESULT")
+
+ for io in [(args.ref, REF), (args.hyp, HYP)]:
+ with open(io[0], "r", encoding="utf8") as fi:
+ with open(io[1], "w+", encoding="utf8") as fo:
+ for line in fi:
+ line = line.strip()
+ if line:
+ cols = line.split()
+ text = asr_text_post_processing(" ".join(cols[0:-1]))
+ uttid_field = cols[-1]
+ print(f"{text} {uttid_field}", file=fo)
+
+ # GigaSpeech's uttid conforms to the swb format
+ os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}")
diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py
new file mode 100644
index 000000000..cdc85ce9a
--- /dev/null
+++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py
@@ -0,0 +1,98 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+class LabelSmoothingLoss(torch.nn.Module):
+ """
+ Implement the LabelSmoothingLoss proposed in the following paper
+ https://arxiv.org/pdf/1512.00567.pdf
+ (Rethinking the Inception Architecture for Computer Vision)
+
+ """
+
+ def __init__(
+ self,
+ ignore_index: int = -1,
+ label_smoothing: float = 0.1,
+ reduction: str = "sum",
+ ) -> None:
+ """
+ Args:
+ ignore_index:
+ ignored class id
+ label_smoothing:
+ smoothing rate (0.0 means the conventional cross entropy loss)
+ reduction:
+ It has the same meaning as the reduction in
+ `torch.nn.CrossEntropyLoss`. It can be one of the following three
+ values: (1) "none": No reduction will be applied. (2) "mean": the
+ mean of the output is taken. (3) "sum": the output will be summed.
+ """
+ super().__init__()
+ assert 0.0 <= label_smoothing < 1.0
+ self.ignore_index = ignore_index
+ self.label_smoothing = label_smoothing
+ self.reduction = reduction
+
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """
+ Compute loss between x and target.
+
+ Args:
+ x:
+ prediction of dimension
+ (batch_size, input_length, number_of_classes).
+ target:
+ target masked with self.ignore_index of
+ dimension (batch_size, input_length).
+
+ Returns:
+ A scalar tensor containing the loss without normalization.
+ """
+ assert x.ndim == 3
+ assert target.ndim == 2
+ assert x.shape[:2] == target.shape
+ num_classes = x.size(-1)
+ x = x.reshape(-1, num_classes)
+ # Now x is of shape (N*T, C)
+
+ # We don't want to change target in-place below,
+ # so we make a copy of it here
+ target = target.clone().reshape(-1)
+
+ ignored = target == self.ignore_index
+ target[ignored] = 0
+
+ true_dist = torch.nn.functional.one_hot(
+ target, num_classes=num_classes
+ ).to(x)
+
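+ # Smooth the one-hot targets below: the correct class keeps probability
+ # (1 - label_smoothing) + label_smoothing / num_classes, while every other
+ # class receives label_smoothing / num_classes.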
+ true_dist = (
+ true_dist * (1 - self.label_smoothing)
+ + self.label_smoothing / num_classes
+ )
+ # Set the value of ignored indexes to 0
+ true_dist[ignored] = 0
+
+ loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
+ if self.reduction == "sum":
+ return loss.sum()
+ elif self.reduction == "mean":
+ return loss.sum() / (~ignored).sum()
+ else:
+ return loss.sum(dim=-1)
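+
+
+# A minimal usage sketch of LabelSmoothingLoss (shapes and values are
+# illustrative only):
+#
+#   criterion = LabelSmoothingLoss(ignore_index=-1, label_smoothing=0.1)
+#   logits = torch.randn(2, 5, 500)         # (batch, time, num_classes)
+#   target = torch.randint(0, 500, (2, 5))  # (batch, time)
+#   loss = criterion(logits, target)        # scalar: sum over non-ignored positions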
diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py
new file mode 100644
index 000000000..542fb0364
--- /dev/null
+++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py
@@ -0,0 +1,161 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+
+
+class Conv2dSubsampling(nn.Module):
+ """Convolutional 2D subsampling (to 1/4 length).
+
+ Convert an input of shape (N, T, idim) to an output
+ with shape (N, T', odim), where
+ T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+
+ It is based on
+ https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
+ """
+
+ def __init__(self, idim: int, odim: int) -> None:
+ """
+ Args:
+ idim:
+ Input dim. The input shape is (N, T, idim).
+ Caution: It requires: T >=7, idim >=7
+ odim:
+ Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
+ """
+ assert idim >= 7
+ super().__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ in_channels=1, out_channels=odim, kernel_size=3, stride=2
+ ),
+ nn.ReLU(),
+ nn.Conv2d(
+ in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+ ),
+ nn.ReLU(),
+ )
+ self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Subsample x.
+
+ Args:
+ x:
+ Its shape is (N, T, idim).
+
+ Returns:
+ Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+ """
+ # On entry, x is (N, T, idim)
+ x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+ x = self.conv(x)
+ # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
+ return x
+
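+# Shape sketch (illustrative numbers): for an input x of shape
+# (N=8, T=100, idim=80), Conv2dSubsampling(idim=80, odim=512)(x) returns a
+# tensor of shape (8, ((100 - 1) // 2 - 1) // 2, 512) = (8, 24, 512).
+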
+
+class VggSubsampling(nn.Module):
+ """Trying to follow the setup described in the following paper:
+ https://arxiv.org/pdf/1910.09799.pdf
+
+ This paper is not 100% explicit so I am guessing to some extent,
+ and trying to compare with other VGG implementations.
+
+ Convert an input of shape (N, T, idim) to an output
+ with shape (N, T', odim), where
+ T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
+ """
+
+ def __init__(self, idim: int, odim: int) -> None:
+ """Construct a VggSubsampling object.
+
+ This uses 2 VGG blocks with 2 Conv2d layers each,
+ subsampling its input by a factor of 4 in the time dimension.
+
+ Args:
+ idim:
+ Input dim. The input shape is (N, T, idim).
+ Caution: It requires: T >=7, idim >=7
+ odim:
+ Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
+ """
+ super().__init__()
+
+ cur_channels = 1
+ layers = []
+ block_dims = [32, 64]
+
+ # The decision to use padding=1 for the 1st convolution, then padding=0
+ # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
+ # a back-compatibility concern so that the number of frames at the
+ # output would be equal to:
+ # (((T-1)//2)-1)//2.
+ # We can consider changing this by using padding=1 on the
+ # 2nd convolution, so the num-frames at the output would be T//4.
+ for block_dim in block_dims:
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=cur_channels,
+ out_channels=block_dim,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ )
+ )
+ layers.append(torch.nn.ReLU())
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=block_dim,
+ out_channels=block_dim,
+ kernel_size=3,
+ padding=0,
+ stride=1,
+ )
+ )
+ layers.append(
+ torch.nn.MaxPool2d(
+ kernel_size=2, stride=2, padding=0, ceil_mode=True
+ )
+ )
+ cur_channels = block_dim
+
+ self.layers = nn.Sequential(*layers)
+
+ self.out = nn.Linear(
+ block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Subsample x.
+
+ Args:
+ x:
+ Its shape is (N, T, idim).
+
+ Returns:
+ Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
+ """
+ x = x.unsqueeze(1)
+ x = self.layers(x)
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+ return x
diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py
new file mode 100755
index 000000000..2965cde18
--- /dev/null
+++ b/egs/gigaspeech/ASR/conformer_ctc/train.py
@@ -0,0 +1,737 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Optional, Tuple
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from conformer import Conformer
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
+
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ encode_supervisions,
+ setup_logger,
+ str2bool,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=20,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ conformer_ctc/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="conformer_ctc/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_bpe_500",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--att-rate",
+ type=float,
+ default=0.7,
+ help="""The attention rate.
+ The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
+ """,
+ )
+
+ parser.add_argument(
+ "--lr-factor",
+ type=float,
+ default=5.0,
+ help="The lr_factor for Noam optimizer",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - use_feat_batchnorm: Normalization for the input features, can be a
+ boolean indicating whether to do batch
+ normalization, or a float which means just scaling
+ the input features with this float value.
+ If given a float value, we will remove batchnorm
+ layer in `ConvolutionModule` as well.
+
+ - attention_dim: Hidden dim for multi-head attention model.
+
+ - nhead: Number of heads in the multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layers in the transformer decoder.
+
+ - beam_size: It is used in k2.ctc_loss
+
+ - reduction: It is used in k2.ctc_loss
+
+ - use_double_scores: It is used in k2.ctc_loss
+
+ - weight_decay: The weight_decay for the optimizer.
+
+ - warm_step: The warm_step for Noam optimizer.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 500,
+ "reset_interval": 2000,
+ "valid_interval": 30000,
+ # parameters for conformer
+ "feature_dim": 80,
+ "subsampling_factor": 4,
+ "use_feat_batchnorm": True,
+ "attention_dim": 512,
+ "nhead": 8,
+ "num_decoder_layers": 6,
+ # parameters for loss
+ "beam_size": 10,
+ "reduction": "sum",
+ "use_double_scores": True,
+ # parameters for Noam
+ "weight_decay": 1e-6,
+ "warm_step": 100000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> None:
+ """Load checkpoint from file.
+
+ If params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+ Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+ it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The learning rate scheduler we are using.
+ Returns:
+ Return None.
+ """
+ if params.start_epoch <= 0:
+ return
+
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ batch: dict,
+ graph_compiler: BpeCtcTrainingGraphCompiler,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute CTC loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ graph_compiler:
+ It is used to build a decoding graph from a ctc topo and training
+ transcript. The training transcript is contained in the given `batch`,
+ while the ctc topo is built when this compiler is instantiated.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = graph_compiler.device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ with torch.set_grad_enabled(is_training):
+ nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+ # nnet_output is (N, T, C)
+
+ # NOTE: We need `encode_supervisions` to sort sequences with
+ # different duration in decreasing order, required by
+ # `k2.intersect_dense` called in `k2.ctc_loss`
+ supervision_segments, texts = encode_supervisions(
+ supervisions, subsampling_factor=params.subsampling_factor
+ )
+
+ token_ids = graph_compiler.texts_to_ids(texts)
+
+ decoding_graph = graph_compiler.compile(token_ids)
+
+ dense_fsa_vec = k2.DenseFsaVec(
+ nnet_output,
+ supervision_segments,
+ allow_truncate=params.subsampling_factor - 1,
+ )
+
+ ctc_loss = k2.ctc_loss(
+ decoding_graph=decoding_graph,
+ dense_fsa_vec=dense_fsa_vec,
+ output_beam=params.beam_size,
+ reduction=params.reduction,
+ use_double_scores=params.use_double_scores,
+ )
+
+ if params.att_rate != 0.0:
+ with torch.set_grad_enabled(is_training):
+ mmodel = model.module if hasattr(model, "module") else model
+ # Note: We need to generate an unsorted version of token_ids
+ # `encode_supervisions()` called above sorts text, but
+ # encoder_memory and memory_mask are not sorted, so we
+ # use an unsorted version `supervisions["text"]` to regenerate
+ # the token_ids
+ #
+ # See https://github.com/k2-fsa/icefall/issues/97
+ # for more details
+ unsorted_token_ids = graph_compiler.texts_to_ids(
+ supervisions["text"]
+ )
+ att_loss = mmodel.decoder_forward(
+ encoder_memory,
+ memory_mask,
+ token_ids=unsorted_token_ids,
+ sos_id=graph_compiler.sos_id,
+ eos_id=graph_compiler.eos_id,
+ )
+ loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+ else:
+ loss = ctc_loss
+ att_loss = torch.tensor([0])
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ info["frames"] = supervision_segments[:, 2].sum().item()
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+ if params.att_rate != 0.0:
+ info["att_loss"] = att_loss.detach().cpu().item()
+
+ info["loss"] = loss.detach().cpu().item()
+
+ return loss, info
+
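+# Note: with the default --att-rate of 0.7, the total loss computed in
+# compute_loss() above is 0.3 * ctc_loss + 0.7 * att_loss.
+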
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ graph_compiler: BpeCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: BpeCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ graph_compiler:
+ It is used to convert transcripts to FSAs.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of GPUs used for DDP training. If it is 1, DDP is disabled.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+
+ optimizer.zero_grad()
+ loss.backward()
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
+ optimizer.step()
+
+ if batch_idx % params.log_interval == 0:
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+
+ if tb_writer is not None:
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(42)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+ logging.info(params)
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ lexicon = Lexicon(params.lang_dir)
+ max_token_id = max(lexicon.tokens)
+ num_classes = max_token_id + 1 # +1 for the blank
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+
+ graph_compiler = BpeCtcTrainingGraphCompiler(
+ params.lang_dir,
+ device=device,
+ sos_token="",
+ eos_token="",
+ )
+
+ logging.info("About to create model")
+ model = Conformer(
+ num_features=params.feature_dim,
+ nhead=params.nhead,
+ d_model=params.attention_dim,
+ num_classes=num_classes,
+ subsampling_factor=params.subsampling_factor,
+ num_decoder_layers=params.num_decoder_layers,
+ vgg_frontend=False,
+ use_feat_batchnorm=params.use_feat_batchnorm,
+ )
+
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ model = DDP(model, device_ids=[rank])
+
+ optimizer = Noam(
+ model.parameters(),
+ model_size=params.attention_dim,
+ factor=params.lr_factor,
+ warm_step=params.warm_step,
+ weight_decay=params.weight_decay,
+ )
+
+ if checkpoints:
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ GigaSpeech = GigaSpeechAsrDataModule(args)
+
+ train_cuts = GigaSpeech.train_cuts()
+ train_dl = GigaSpeech.train_dataloaders(train_cuts)
+
+ valid_cuts = GigaSpeech.dev_cuts()
+ valid_dl = GigaSpeech.valid_dataloaders(valid_cuts)
+
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ for epoch in range(params.start_epoch, params.num_epochs):
+ train_dl.sampler.set_epoch(epoch)
+
+ cur_lr = optimizer._rate
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ if rank == 0:
+ logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ )
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+ model: nn.Module,
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: BpeCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ optimizer.zero_grad()
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=True,
+ )
+ loss.backward()
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
+ optimizer.step()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py
new file mode 100644
index 000000000..00ca027a7
--- /dev/null
+++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py
@@ -0,0 +1,953 @@
+# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from label_smoothing import LabelSmoothingLoss
+from subsampling import Conv2dSubsampling, VggSubsampling
+from torch.nn.utils.rnn import pad_sequence
+
+# Note: TorchScript requires Dict/List/etc. to be fully typed.
+Supervisions = Dict[str, torch.Tensor]
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ num_features: int,
+ num_classes: int,
+ subsampling_factor: int = 4,
+ d_model: int = 256,
+ nhead: int = 4,
+ dim_feedforward: int = 2048,
+ num_encoder_layers: int = 12,
+ num_decoder_layers: int = 6,
+ dropout: float = 0.1,
+ normalize_before: bool = True,
+ vgg_frontend: bool = False,
+ use_feat_batchnorm: Union[float, bool] = 0.1,
+ ) -> None:
+ """
+ Args:
+ num_features:
+ The input dimension of the model.
+ num_classes:
+ The output dimension of the model.
+ subsampling_factor:
+ Number of output frames is num_in_frames // subsampling_factor.
+ Currently, subsampling_factor MUST be 4.
+ d_model:
+ Attention dimension.
+ nhead:
+ Number of heads in multi-head attention.
+ Must satisfy d_model % nhead == 0.
+ dim_feedforward:
+ The output dimension of the feedforward layers in encoder/decoder.
+ num_encoder_layers:
+ Number of encoder layers.
+ num_decoder_layers:
+ Number of decoder layers.
+ dropout:
+ Dropout in encoder/decoder.
+ normalize_before:
+ If True, use pre-layer norm; False to use post-layer norm.
+ vgg_frontend:
+ True to use vgg style frontend for subsampling.
+ use_feat_batchnorm:
+ True to use batchnorm for the input layer.
+ Float value to scale the input layer.
+ False to do nothing.
+ """
+ super().__init__()
+ self.use_feat_batchnorm = use_feat_batchnorm
+ assert isinstance(use_feat_batchnorm, (float, bool))
+ if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm:
+ self.feat_batchnorm = nn.BatchNorm1d(num_features)
+
+ self.num_features = num_features
+ self.num_classes = num_classes
+ self.subsampling_factor = subsampling_factor
+ if subsampling_factor != 4:
+ raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+ # self.encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, T//subsampling_factor, d_model).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> T//subsampling_factor
+ # (2) embedding: num_features -> d_model
+ if vgg_frontend:
+ self.encoder_embed = VggSubsampling(num_features, d_model)
+ else:
+ self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+ self.encoder_pos = PositionalEncoding(d_model, dropout)
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ normalize_before=normalize_before,
+ )
+
+ if normalize_before:
+ encoder_norm = nn.LayerNorm(d_model)
+ else:
+ encoder_norm = None
+
+ self.encoder = nn.TransformerEncoder(
+ encoder_layer=encoder_layer,
+ num_layers=num_encoder_layers,
+ norm=encoder_norm,
+ )
+
+ # TODO(fangjun): remove dropout
+ self.encoder_output_layer = nn.Sequential(
+ nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+ )
+
+ if num_decoder_layers > 0:
+ self.decoder_num_class = (
+ self.num_classes
+ ) # bpe model already has sos/eos symbol
+
+ self.decoder_embed = nn.Embedding(
+ num_embeddings=self.decoder_num_class, embedding_dim=d_model
+ )
+ self.decoder_pos = PositionalEncoding(d_model, dropout)
+
+ decoder_layer = TransformerDecoderLayer(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ normalize_before=normalize_before,
+ )
+
+ if normalize_before:
+ decoder_norm = nn.LayerNorm(d_model)
+ else:
+ decoder_norm = None
+
+ self.decoder = nn.TransformerDecoder(
+ decoder_layer=decoder_layer,
+ num_layers=num_decoder_layers,
+ norm=decoder_norm,
+ )
+
+ self.decoder_output_layer = torch.nn.Linear(
+ d_model, self.decoder_num_class
+ )
+
+ self.decoder_criterion = LabelSmoothingLoss()
+ else:
+ self.decoder_criterion = None
+
+ def forward(
+ self, x: torch.Tensor, supervision: Optional[Supervisions] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (N, T, C).
+ supervision:
+ Supervision in lhotse format.
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
+ (CAUTION: It contains length information, i.e., start and number of
+ frames, before subsampling)
+
+ Returns:
+ Return a tuple containing 3 tensors:
+ - CTC output for ctc decoding. Its shape is (N, T, C)
+ - Encoder output with shape (T, N, C). It can be used as key and
+ value for the decoder.
+ - Encoder output padding mask. It can be used as
+ memory_key_padding_mask for the decoder. Its shape is (N, T).
+ It is None if `supervision` is None.
+ """
+ if (
+ isinstance(self.use_feat_batchnorm, bool)
+ and self.use_feat_batchnorm
+ ):
+ x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
+ x = self.feat_batchnorm(x)
+ x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
+ if isinstance(self.use_feat_batchnorm, float):
+ x *= self.use_feat_batchnorm
+ encoder_memory, memory_key_padding_mask = self.run_encoder(
+ x, supervision
+ )
+ x = self.ctc_output(encoder_memory)
+ return x, encoder_memory, memory_key_padding_mask
+
+ def run_encoder(
+ self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Run the transformer encoder.
+
+ Args:
+ x:
+ The model input. Its shape is (N, T, C).
+ supervisions:
+ Supervision in lhotse format.
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
+ CAUTION: It contains length information, i.e., start and number of
+ frames, before subsampling
+ It is read directly from the batch, without any sorting. It is used
+ to compute the encoder padding mask, which is used as memory key
+ padding mask for the decoder.
+ Returns:
+ Return a tuple with two tensors:
+ - The encoder output, with shape (T, N, C)
+ - encoder padding mask, with shape (N, T).
+ The mask is None if `supervisions` is None.
+ It is used as memory key padding mask in the decoder.
+ """
+ x = self.encoder_embed(x)
+ x = self.encoder_pos(x)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ mask = encoder_padding_mask(x.size(0), supervisions)
+ mask = mask.to(x.device) if mask is not None else None
+ x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
+
+ return x, mask
+
+ def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x:
+ The output tensor from the transformer encoder.
+ Its shape is (T, N, C)
+
+ Returns:
+ Return a tensor that can be used for CTC decoding.
+ Its shape is (N, T, C)
+ """
+ x = self.encoder_output_layer(x)
+ x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+ x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
+ return x
+
+ @torch.jit.export
+ def decoder_forward(
+ self,
+ memory: torch.Tensor,
+ memory_key_padding_mask: torch.Tensor,
+ token_ids: List[List[int]],
+ sos_id: int,
+ eos_id: int,
+ ) -> torch.Tensor:
+ """
+ Args:
+ memory:
+ It's the output of the encoder with shape (T, N, C)
+ memory_key_padding_mask:
+ The padding mask from the encoder.
+ token_ids:
+ A list of lists of IDs. Each sublist contains IDs for an utterance.
+ The IDs can be either phone IDs or word piece IDs.
+ sos_id:
+ sos token id
+ eos_id:
+ eos token id
+
+ Returns:
+ A scalar, the **sum** of label smoothing loss over utterances
+ in the batch without any normalization.
+ """
+ ys_in = add_sos(token_ids, sos_id=sos_id)
+ ys_in = [torch.tensor(y) for y in ys_in]
+ ys_in_pad = pad_sequence(
+ ys_in, batch_first=True, padding_value=float(eos_id)
+ )
+
+ ys_out = add_eos(token_ids, eos_id=eos_id)
+ ys_out = [torch.tensor(y) for y in ys_out]
+ ys_out_pad = pad_sequence(
+ ys_out, batch_first=True, padding_value=float(-1)
+ )
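+ # For example (hypothetical IDs), with sos_id = eos_id = 1 and
+ # token_ids = [[5, 9], [7]], the two calls above produce
+ #   ys_in  = [[1, 5, 9], [1, 7]]  (sos prepended)
+ #   ys_out = [[5, 9, 1], [7, 1]]  (eos appended)
+ # before padding with pad_sequence().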
+
+ device = memory.device
+ ys_in_pad = ys_in_pad.to(device)
+ ys_out_pad = ys_out_pad.to(device)
+
+ tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+ device
+ )
+
+ tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+ # TODO: Use length information to create the decoder padding mask
+ # We set the first column to False since the first column in ys_in_pad
+ # contains sos_id, which is the same as eos_id in our current setting.
+ tgt_key_padding_mask[:, 0] = False
+
+ tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
+ tgt = self.decoder_pos(tgt)
+ tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ pred_pad = self.decoder(
+ tgt=tgt,
+ memory=memory,
+ tgt_mask=tgt_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ ) # (T, N, C)
+ pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
+ pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
+
+ decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
+
+ return decoder_loss
+
+ @torch.jit.export
+ def decoder_nll(
+ self,
+ memory: torch.Tensor,
+ memory_key_padding_mask: torch.Tensor,
+ token_ids: List[torch.Tensor],
+ sos_id: int,
+ eos_id: int,
+ ) -> torch.Tensor:
+ """
+ Args:
+ memory:
+ It's the output of the encoder with shape (T, N, C)
+ memory_key_padding_mask:
+ The padding mask from the encoder.
+ token_ids:
+ A list of lists of IDs (e.g., word piece IDs).
+ Each sublist represents an utterance.
+ sos_id:
+ The token ID for SOS.
+ eos_id:
+ The token ID for EOS.
+ Returns:
+ A 2-D tensor of shape (len(token_ids), max_token_length)
+ representing the cross entropy loss (i.e., negative log-likelihood).
+ """
+ # The common part between this function and decoder_forward could be
+ # extracted as a separate function.
+ if isinstance(token_ids[0], torch.Tensor):
+ # This branch is executed by torchscript in C++.
+ # See https://github.com/k2-fsa/k2/pull/870
+ # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
+ token_ids = [tolist(t) for t in token_ids]
+
+ ys_in = add_sos(token_ids, sos_id=sos_id)
+ ys_in = [torch.tensor(y) for y in ys_in]
+ ys_in_pad = pad_sequence(
+ ys_in, batch_first=True, padding_value=float(eos_id)
+ )
+
+ ys_out = add_eos(token_ids, eos_id=eos_id)
+ ys_out = [torch.tensor(y) for y in ys_out]
+ ys_out_pad = pad_sequence(
+ ys_out, batch_first=True, padding_value=float(-1)
+ )
+
+ device = memory.device
+ ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
+ ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
+
+ tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
+ device
+ )
+
+ tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+ # TODO: Use length information to create the decoder padding mask
+ # We set the first column to False since the first column in ys_in_pad
+ # contains sos_id, which is the same as eos_id in our current setting.
+ tgt_key_padding_mask[:, 0] = False
+
+ tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
+ tgt = self.decoder_pos(tgt)
+ tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
+ pred_pad = self.decoder(
+ tgt=tgt,
+ memory=memory,
+ tgt_mask=tgt_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ ) # (T, B, F)
+ pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F)
+ pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F)
+ # nll: negative log-likelihood
+ nll = torch.nn.functional.cross_entropy(
+ pred_pad.view(-1, self.decoder_num_class),
+ ys_out_pad.view(-1),
+ ignore_index=-1,
+ reduction="none",
+ )
+
+ nll = nll.view(pred_pad.shape[0], -1)
+
+ return nll
+
+
+class TransformerEncoderLayer(nn.Module):
+ """
+ Modified from torch.nn.TransformerEncoderLayer.
+ Adds support for normalize_before,
+ i.e., using layer_norm before the first block.
+
+ Args:
+ d_model:
+ the number of expected features in the input (required).
+ nhead:
+ the number of heads in the multiheadattention models (required).
+ dim_feedforward:
+ the dimension of the feedforward network model (default=2048).
+ dropout:
+ the dropout value (default=0.1).
+ activation:
+ the activation function of intermediate layer, relu or
+ gelu (default=relu).
+ normalize_before:
+ whether to use layer_norm before the first block.
+
+ Examples::
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = encoder_layer(src)
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ activation: str = "relu",
+ normalize_before: bool = True,
+ ) -> None:
+ super(TransformerEncoderLayer, self).__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ self.normalize_before = normalize_before
+
+ def __setstate__(self, state):
+ if "activation" not in state:
+ state["activation"] = nn.functional.relu
+ super(TransformerEncoderLayer, self).__setstate__(state)
+
+ def forward(
+ self,
+ src: torch.Tensor,
+ src_mask: Optional[torch.Tensor] = None,
+ src_key_padding_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Pass the input through the encoder layer.
+
+ Args:
+ src: the sequence to the encoder layer (required).
+ src_mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional)
+
+ Shape:
+ src: (S, N, E).
+ src_mask: (S, S).
+ src_key_padding_mask: (N, S).
+ S is the source sequence length, N is the batch size,
+ E is the feature (embedding) dimension.
+ """
+ residual = src
+ if self.normalize_before:
+ src = self.norm1(src)
+ src2 = self.self_attn(
+ src,
+ src,
+ src,
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask,
+ )[0]
+ src = residual + self.dropout1(src2)
+ if not self.normalize_before:
+ src = self.norm1(src)
+
+ residual = src
+ if self.normalize_before:
+ src = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = residual + self.dropout2(src2)
+ if not self.normalize_before:
+ src = self.norm2(src)
+ return src
+
+
+class TransformerDecoderLayer(nn.Module):
+ """
+ Modified from torch.nn.TransformerDecoderLayer.
+ Adds support for normalize_before,
+ i.e., using layer_norm before the first block.
+
+ Args:
+ d_model:
+ the number of expected features in the input (required).
+ nhead:
+ the number of heads in the multiheadattention models (required).
+ dim_feedforward:
+ the dimension of the feedforward network model (default=2048).
+ dropout:
+ the dropout value (default=0.1).
+ activation:
+ the activation function of intermediate layer, relu or
+ gelu (default=relu).
+
+ Examples::
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
+ >>> memory = torch.rand(10, 32, 512)
+ >>> tgt = torch.rand(20, 32, 512)
+ >>> out = decoder_layer(tgt, memory)
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ activation: str = "relu",
+ normalize_before: bool = True,
+ ) -> None:
+ super(TransformerDecoderLayer, self).__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
+ self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ self.normalize_before = normalize_before
+
+ def __setstate__(self, state):
+ if "activation" not in state:
+ state["activation"] = nn.functional.relu
+ super(TransformerDecoderLayer, self).__setstate__(state)
+
+ def forward(
+ self,
+ tgt: torch.Tensor,
+ memory: torch.Tensor,
+ tgt_mask: Optional[torch.Tensor] = None,
+ memory_mask: Optional[torch.Tensor] = None,
+ tgt_key_padding_mask: Optional[torch.Tensor] = None,
+ memory_key_padding_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Pass the inputs (and mask) through the decoder layer.
+
+ Args:
+ tgt:
+ the sequence to the decoder layer (required).
+ memory:
+ the sequence from the last layer of the encoder (required).
+ tgt_mask:
+ the mask for the tgt sequence (optional).
+ memory_mask:
+ the mask for the memory sequence (optional).
+ tgt_key_padding_mask:
+ the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask:
+ the mask for the memory keys per batch (optional).
+
+ Shape:
+ tgt: (T, N, E).
+ memory: (S, N, E).
+ tgt_mask: (T, T).
+ memory_mask: (T, S).
+ tgt_key_padding_mask: (N, T).
+ memory_key_padding_mask: (N, S).
+ S is the source sequence length, T is the target sequence length,
+ N is the batch size, E is the feature number
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+ tgt2 = self.self_attn(
+ tgt,
+ tgt,
+ tgt,
+ attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask,
+ )[0]
+ tgt = residual + self.dropout1(tgt2)
+ if not self.normalize_before:
+ tgt = self.norm1(tgt)
+
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm2(tgt)
+ tgt2 = self.src_attn(
+ tgt,
+ memory,
+ memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = residual + self.dropout2(tgt2)
+ if not self.normalize_before:
+ tgt = self.norm2(tgt)
+
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = residual + self.dropout3(tgt2)
+ if not self.normalize_before:
+ tgt = self.norm3(tgt)
+ return tgt
+
+
+def _get_activation_fn(activation: str):
+ if activation == "relu":
+ return nn.functional.relu
+ elif activation == "gelu":
+ return nn.functional.gelu
+
+ raise RuntimeError(
+ "activation should be relu/gelu, not {}".format(activation)
+ )
+
+
+class PositionalEncoding(nn.Module):
+ """This class implements the positional encoding
+ proposed in the following paper:
+
+ - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
+
+ PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
+ PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
+
+ Note::
+
+ 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
+ = exp(-1 * 2i / d_model * log(10000))
+ = exp(2i * -(log(10000) / d_model))
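+
+ Example (a minimal usage sketch; the tensor shapes below are illustrative)::
+
+ >>> pos_enc = PositionalEncoding(d_model=512, dropout=0.1)
+ >>> x = torch.rand(8, 100, 512) # (N, T, C)
+ >>> y = pos_enc(x) # x * sqrt(d_model) + PE, followed by dropout
+ >>> y.shape
+ torch.Size([8, 100, 512])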
+ """
+
+ def __init__(self, d_model: int, dropout: float = 0.1) -> None:
+ """
+ Args:
+ d_model:
+ Embedding dimension.
+ dropout:
+ Dropout probability to be applied to the output of this module.
+ """
+ super().__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = nn.Dropout(p=dropout)
+ # not doing: self.pe = None because of errors thrown by torchscript
+ self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
+
+ def extend_pe(self, x: torch.Tensor) -> None:
+ """Extend the time t in the positional encoding if required.
+
+ The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+ is (N, T, d_model). If T > T1, then we change the shape of self.pe
+ to (1, T, d_model). Otherwise, nothing is done.
+
+ Args:
+ x:
+ It is a tensor of shape (N, T, C).
+ Returns:
+ Return None.
+ """
+ if self.pe is not None:
+ if self.pe.size(1) >= x.size(1):
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ # Now pe is of shape (1, T, d_model), where T is x.size(1)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Add positional encoding.
+
+ Args:
+ x:
+ Its shape is (N, T, C)
+
+ Returns:
+ Return a tensor of shape (N, T, C)
+ """
+ self.extend_pe(x)
+ x = x * self.xscale + self.pe[:, : x.size(1), :]
+ return self.dropout(x)
+
+
+class Noam(object):
+ """
+ Implements Noam optimizer.
+
+ Proposed in
+ "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
+
+ Modified from
+ https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
+
+ Args:
+ params:
+ iterable of parameters to optimize or dicts defining parameter groups
+ model_size:
+ attention dimension of the transformer model
+ factor:
+ learning rate factor
+ warm_step:
+ warmup steps
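+
+ Example (a minimal sketch; the linear model below is only a stand-in
+ used to illustrate the optimizer interface)::
+
+ >>> model = torch.nn.Linear(10, 10)
+ >>> optimizer = Noam(model.parameters(), model_size=256, factor=10.0)
+ >>> loss = model(torch.randn(2, 10)).sum()
+ >>> optimizer.zero_grad()
+ >>> loss.backward()
+ >>> optimizer.step() # lr = factor * model_size^-0.5 * min(step^-0.5, step * warm_step^-1.5)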
+ """
+
+ def __init__(
+ self,
+ params,
+ model_size: int = 256,
+ factor: float = 10.0,
+ warm_step: int = 25000,
+ weight_decay=0,
+ ) -> None:
+ """Construct a Noam object."""
+ self.optimizer = torch.optim.Adam(
+ params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
+ )
+ self._step = 0
+ self.warmup = warm_step
+ self.factor = factor
+ self.model_size = model_size
+ self._rate = 0
+
+ @property
+ def param_groups(self):
+ """Return param_groups."""
+ return self.optimizer.param_groups
+
+ def step(self):
+ """Update parameters and rate."""
+ self._step += 1
+ rate = self.rate()
+ for p in self.optimizer.param_groups:
+ p["lr"] = rate
+ self._rate = rate
+ self.optimizer.step()
+
+ def rate(self, step=None):
+ """Compute the learning rate for a given step (defaults to the current step)."""
+ if step is None:
+ step = self._step
+ return (
+ self.factor
+ * self.model_size ** (-0.5)
+ * min(step ** (-0.5), step * self.warmup ** (-1.5))
+ )
+
+ def zero_grad(self):
+ """Reset gradient."""
+ self.optimizer.zero_grad()
+
+ def state_dict(self):
+ """Return state_dict."""
+ return {
+ "_step": self._step,
+ "warmup": self.warmup,
+ "factor": self.factor,
+ "model_size": self.model_size,
+ "_rate": self._rate,
+ "optimizer": self.optimizer.state_dict(),
+ }
+
+ def load_state_dict(self, state_dict):
+ """Load state_dict."""
+ for key, value in state_dict.items():
+ if key == "optimizer":
+ self.optimizer.load_state_dict(state_dict["optimizer"])
+ else:
+ setattr(self, key, value)
+
+
+def encoder_padding_mask(
+ max_len: int, supervisions: Optional[Supervisions] = None
+) -> Optional[torch.Tensor]:
+ """Make mask tensor containing indexes of padded part.
+
+ TODO::
+ This function **assumes** that the model uses
+ a subsampling factor of 4. We should remove that
+ assumption later.
+
+ Args:
+ max_len:
+ Maximum length of input features.
+ CAUTION: It is the length after subsampling.
+ supervisions:
+ Supervision in lhotse format.
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
+ (CAUTION: It contains length information, i.e., start and number of
+ frames, before subsampling)
+
+ Returns:
+ Tensor: Mask tensor of dimension (batch_size, input_length),
+ where True denotes a masked (i.e., padded) position.
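+
+ Example (a toy sketch with made-up supervision values; note that
+ start_frame and num_frames are counted before subsampling)::
+
+ >>> supervisions = {
+ ... "sequence_idx": torch.tensor([0, 1]),
+ ... "start_frame": torch.tensor([0, 0]),
+ ... "num_frames": torch.tensor([20, 12]),
+ ... }
+ >>> encoder_padding_mask(4, supervisions)
+ tensor([[False, False, False, False],
+ [False, False, True, True]])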
+ """
+ if supervisions is None:
+ return None
+
+ supervision_segments = torch.stack(
+ (
+ supervisions["sequence_idx"],
+ supervisions["start_frame"],
+ supervisions["num_frames"],
+ ),
+ 1,
+ ).to(torch.int32)
+
+ lengths = [
+ 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
+ ]
+ for idx in range(supervision_segments.size(0)):
+ # Note: TorchScript doesn't allow to unpack tensors as tuples
+ sequence_idx = supervision_segments[idx, 0].item()
+ start_frame = supervision_segments[idx, 1].item()
+ num_frames = supervision_segments[idx, 2].item()
+ lengths[sequence_idx] = start_frame + num_frames
+
+ lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
+ bs = int(len(lengths))
+ seq_range = torch.arange(0, max_len, dtype=torch.int64)
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
+ # Note: TorchScript doesn't implement Tensor.new()
+ seq_length_expand = torch.tensor(
+ lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
+ ).unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+
+ return mask
+
+
+def decoder_padding_mask(
+ ys_pad: torch.Tensor, ignore_id: int = -1
+) -> torch.Tensor:
+ """Generate a length mask for input.
+
+ Masked positions are filled with True;
+ unmasked positions are filled with False.
+
+ Args:
+ ys_pad:
+ padded tensor of dimension (batch_size, input_length).
+ ignore_id:
+ the ignored number (the padding number) in ys_pad
+
+ Returns:
+ Tensor:
+ a bool tensor of the same shape as the input tensor.
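+
+ Example (illustrative)::
+
+ >>> ys_pad = torch.tensor([[1, 2, 3, -1, -1], [1, 2, -1, -1, -1]])
+ >>> decoder_padding_mask(ys_pad, ignore_id=-1)
+ tensor([[False, False, False, True, True],
+ [False, False, True, True, True]])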
+ """
+ ys_mask = ys_pad == ignore_id
+ return ys_mask
+
+
+def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
+ """Generate a square mask for the sequence. The masked positions are
+ filled with float('-inf'). Unmasked positions are filled with float(0.0).
+ The mask can be used for masked self-attention.
+
+ For instance, if sz is 3, it returns::
+
+ tensor([[0., -inf, -inf],
+ [0., 0., -inf],
+ [0., 0., 0.]])
+
+ Args:
+ sz: mask size
+
+ Returns:
+ A square mask of dimension (sz, sz)
+ """
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+ mask = (
+ mask.float()
+ .masked_fill(mask == 0, float("-inf"))
+ .masked_fill(mask == 1, float(0.0))
+ )
+ return mask
+
+
+def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
+ """Prepend sos_id to each utterance.
+
+ Args:
+ token_ids:
+ A list-of-list of token IDs. Each sublist contains
+ token IDs (e.g., word piece IDs) of an utterance.
+ sos_id:
+ The ID of the SOS token.
+
+ Return:
+ Return a new list-of-list, where each sublist starts
+ with SOS ID.
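+
+ Example (an illustrative sketch)::
+
+ >>> add_sos([[1, 2], [3]], sos_id=0)
+ [[0, 1, 2], [0, 3]]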
+ """
+ return [[sos_id] + utt for utt in token_ids]
+
+
+def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
+ """Append eos_id to each utterance.
+
+ Args:
+ token_ids:
+ A list-of-list of token IDs. Each sublist contains
+ token IDs (e.g., word piece IDs) of an utterance.
+ eos_id:
+ The ID of the EOS token.
+
+ Return:
+ Return a new list-of-list, where each sublist ends
+ with EOS ID.
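+
+ Example (an illustrative sketch)::
+
+ >>> add_eos([[1, 2], [3]], eos_id=0)
+ [[1, 2, 0], [3, 0]]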
+ """
+ return [utt + [eos_id] for utt in token_ids]
+
+
+def tolist(t: torch.Tensor) -> List[int]:
+ """Used by jit"""
+ return torch.jit.annotate(List[int], t.tolist())
diff --git a/egs/gigaspeech/ASR/local/__init__.py b/egs/gigaspeech/ASR/local/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/gigaspeech/ASR/local/compile_hlg.py b/egs/gigaspeech/ASR/local/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/gigaspeech/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
new file mode 100755
index 000000000..9f1039893
--- /dev/null
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3
+# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
+# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from pathlib import Path
+
+import torch
+from lhotse import (
+ CutSet,
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+)
+
+# Torch's multithreaded behavior needs to be disabled or
+ # it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_gigaspeech_dev_test():
+ in_out_dir = Path("data/fbank")
+ # number of workers in dataloader
+ num_workers = 20
+
+ # number of seconds in a batch
+ batch_duration = 600
+
+ subsets = ("DEV", "TEST")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+
+ logging.info(f"device: {device}")
+
+ for partition in subsets:
+ cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz"
+ if cuts_path.is_file():
+ logging.info(f"{cuts_path} exists - skipping")
+ continue
+
+ raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz"
+
+ logging.info(f"Loading {raw_cuts_path}")
+ cut_set = CutSet.from_file(raw_cuts_path)
+
+ logging.info("Computing features")
+
+ cut_set = cut_set.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=f"{in_out_dir}/feats_{partition}",
+ num_workers=num_workers,
+ batch_duration=batch_duration,
+ )
+ cut_set = cut_set.trim_to_supervisions(
+ keep_overlapping=False, min_duration=None
+ )
+
+ logging.info(f"Saving to {cuts_path}")
+ cut_set.to_file(cuts_path)
+ logging.info(f"Saved to {cuts_path}")
+
+
+def main():
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ compute_fbank_gigaspeech_dev_test()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
new file mode 100755
index 000000000..9dd3c046d
--- /dev/null
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
+# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from datetime import datetime
+from pathlib import Path
+
+import torch
+from lhotse import (
+ CutSet,
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+)
+
+# Torch's multithreaded behavior needs to be disabled or
+ # it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--num-workers",
+ type=int,
+ default=20,
+ help="Number of dataloading workers used for reading the audio.",
+ )
+ parser.add_argument(
+ "--batch-duration",
+ type=float,
+ default=600.0,
+ help="The maximum number of audio seconds in a batch. "
+ "Determines batch size dynamically.",
+ )
+
+ parser.add_argument(
+ "--num-splits",
+ type=int,
+ required=True,
+ help="The number of splits of the XL subset",
+ )
+
+ parser.add_argument(
+ "--start",
+ type=int,
+ default=0,
+ help="Process pieces starting from this number (inclusive).",
+ )
+
+ parser.add_argument(
+ "--stop",
+ type=int,
+ default=-1,
+ help="Stop processing pieces at this number (exclusive).",
+ )
+ return parser
+
+
+def compute_fbank_gigaspeech_splits(args):
+ num_splits = args.num_splits
+ output_dir = "data/fbank/XL_split"
+ output_dir = Path(output_dir)
+ assert output_dir.exists(), f"{output_dir} does not exist!"
+
+ num_digits = 8 # num_digits is fixed by lhotse split-lazy
+
+ start = args.start
+ stop = args.stop
+ if stop < start:
+ stop = num_splits
+
+ stop = min(stop, num_splits)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+ logging.info(f"device: {device}")
+
+ for i in range(start, stop):
+ idx = f"{i + 1}".zfill(num_digits)
+ logging.info(f"Processing {idx}/{num_splits}")
+
+ cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
+ if cuts_path.is_file():
+ logging.info(f"{cuts_path} exists - skipping")
+ continue
+
+ raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
+
+ logging.info(f"Loading {raw_cuts_path}")
+ cut_set = CutSet.from_file(raw_cuts_path)
+
+ logging.info("Computing features")
+
+ cut_set = cut_set.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=f"{output_dir}/feats_XL_{idx}",
+ num_workers=args.num_workers,
+ batch_duration=args.batch_duration,
+ )
+
+ logging.info("About to split cuts into smaller chunks.")
+ cut_set = cut_set.trim_to_supervisions(
+ keep_overlapping=False, min_duration=None
+ )
+
+ logging.info(f"Saving to {cuts_path}")
+ cut_set.to_file(cuts_path)
+ logging.info(f"Saved to {cuts_path}")
+
+
+def main():
+ now = datetime.now()
+ date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
+
+ log_filename = "log-compute_fbank_gigaspeech_splits"
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+ log_filename = f"{log_filename}-{date_time}"
+
+ logging.basicConfig(
+ filename=log_filename,
+ format=formatter,
+ level=logging.INFO,
+ filemode="w",
+ )
+
+ console = logging.StreamHandler()
+ console.setLevel(logging.INFO)
+ console.setFormatter(logging.Formatter(formatter))
+ logging.getLogger("").addHandler(console)
+
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+
+ compute_fbank_gigaspeech_splits(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_musan.py b/egs/gigaspeech/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/gigaspeech/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py
new file mode 120000
index 000000000..2ce13fd69
--- /dev/null
+++ b/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/local/prepare_lang.py b/egs/gigaspeech/ASR/local/prepare_lang.py
new file mode 120000
index 000000000..747f2ab39
--- /dev/null
+++ b/egs/gigaspeech/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/local/prepare_lang_bpe.py b/egs/gigaspeech/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/gigaspeech/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
new file mode 100755
index 000000000..0cec82ad5
--- /dev/null
+++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# Copyright 2021 Johns Hopkins University (Piotr Żelasko)
+# Copyright 2021 Xiaomi Corp. (Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+from pathlib import Path
+
+from lhotse import CutSet, SupervisionSegment
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Similar text filtering and normalization procedure as in:
+# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
+
+
+def normalize_text(
+ utt: str,
+ punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
+ whitespace_pattern=re.compile(r"\s\s+"),
+) -> str:
+ return whitespace_pattern.sub(" ", punct_pattern.sub("", utt))
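+
+ # A hedged illustration (the input string is made up): with the default
+ # patterns, normalize_text("HELLO <COMMA> WORLD") returns "HELLO WORLD";
+ # the punctuation tag is removed and the doubled whitespace is squeezed.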
+
+
+def has_no_oov(
+ sup: SupervisionSegment,
+ oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"),
+) -> bool:
+ return oov_pattern.search(sup.text) is None
+
+
+def preprocess_giga_speech():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+ output_dir.mkdir(exist_ok=True)
+
+ dataset_parts = (
+ "DEV",
+ "TEST",
+ "XL",
+ )
+
+ logging.info("Loading manifest (may take 4 minutes)")
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix="gigaspeech",
+ suffix="jsonl.gz",
+ )
+ assert manifests is not None
+
+ for partition, m in manifests.items():
+ logging.info(f"Processing {partition}")
+ raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
+ if raw_cuts_path.is_file():
+ logging.info(f"{partition} already exists - skipping")
+ continue
+
+ # Note this step makes the recipe different from LibriSpeech:
+ # We must filter out some utterances and remove punctuation
+ # to be consistent with Kaldi.
+ logging.info("Filtering OOV utterances from supervisions")
+ m["supervisions"] = m["supervisions"].filter(has_no_oov)
+ logging.info(f"Normalizing text in {partition}")
+ for sup in m["supervisions"]:
+ sup.text = normalize_text(sup.text)
+
+ # Create long-recording cut manifests.
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ # Run data augmentation that needs to be done in the
+ # time domain.
+ if partition not in ["DEV", "TEST"]:
+ logging.info(
+ f"Speed perturb for {partition} with factors 0.9 and 1.1 "
+ "(Perturbing may take 8 minutes and saving may take 20 minutes)"
+ )
+ cut_set = (
+ cut_set
+ + cut_set.perturb_speed(0.9)
+ + cut_set.perturb_speed(1.1)
+ )
+ logging.info(f"Saving to {raw_cuts_path}")
+ cut_set.to_file(raw_cuts_path)
+
+
+def main():
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ preprocess_giga_speech()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/ASR/local/train_bpe_model.py b/egs/gigaspeech/ASR/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/gigaspeech/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh
new file mode 100755
index 000000000..fd2532741
--- /dev/null
+++ b/egs/gigaspeech/ASR/prepare.sh
@@ -0,0 +1,325 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+nj=15
+stage=0
+stop_stage=100
+
+# Split XL subset to a number of pieces (about 2000)
+# This is to avoid OOM during feature extraction.
+num_per_split=50
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/GigaSpeech
+# You can find audio, dict, GigaSpeech.json inside it.
+# You can apply for the download credentials by following
+# https://github.com/SpeechColab/GigaSpeech#download
+#
+# - $dl_dir/lm
+# This directory contains the language model downloaded from
+# https://huggingface.co/wgb14/gigaspeech_lm
+#
+# - 3gram_pruned_1e7.arpa.gz
+# - 4gram.arpa.gz
+# - lexicon.txt
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+ 500
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+ log "stage -1: Download LM"
+ # We assume that you have installed git-lfs. If not, you can install it
+ # using: `sudo apt-get install git-lfs && git-lfs install`
+ [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
+ git clone https://huggingface.co/wgb14/gigaspeech_lm $dl_dir/lm
+ gunzip -c $dl_dir/lm/3gram_pruned_1e7.arpa.gz > $dl_dir/lm/3gram_pruned_1e7.arpa
+ gunzip -c $dl_dir/lm/4gram.arpa.gz > $dl_dir/lm/4gram.arpa
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ [ ! -e $dl_dir/GigaSpeech ] && mkdir -p $dl_dir/GigaSpeech
+
+ # If you have pre-downloaded it to /path/to/GigaSpeech,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/GigaSpeech $dl_dir/GigaSpeech
+ #
+ if [ ! -d $dl_dir/GigaSpeech/audio ] && [ ! -f $dl_dir/GigaSpeech.json ]; then
+ # Check credentials.
+ if [ ! -f $dl_dir/password ]; then
+ echo -n "$0: Please apply for the download credentials by following "
+ echo -n "https://github.com/SpeechColab/GigaSpeech#download"
+ echo " and save it to $dl_dir/password."
+ exit 1;
+ fi
+ PASSWORD=`cat $dl_dir/password 2>/dev/null`
+ if [ -z "$PASSWORD" ]; then
+ echo "$0: Error, $dl_dir/password is empty."
+ exit 1;
+ fi
+ PASSWORD_MD5=`echo $PASSWORD | md5sum | cut -d ' ' -f 1`
+ if [[ $PASSWORD_MD5 != "dfbf0cde1a3ce23749d8d81e492741b8" ]]; then
+ echo "$0: Error, invalid $dl_dir/password."
+ exit 1;
+ fi
+ # Download XL, DEV and TEST sets by default.
+ lhotse download gigaspeech --subset auto --host tsinghua \
+ $dl_dir/password $dl_dir/GigaSpeech
+ fi
+
+ # If you have pre-downloaded it to /path/to/musan,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/musan $dl_dir/
+ #
+ if [ ! -d $dl_dir/musan ]; then
+ lhotse download musan $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare GigaSpeech manifest (may take 15 minutes)"
+ # We assume that you have downloaded the GigaSpeech corpus
+ # to $dl_dir/GigaSpeech
+ mkdir -p data/manifests
+ lhotse prepare gigaspeech --subset auto -j $nj \
+ $dl_dir/GigaSpeech data/manifests
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to $dl_dir/musan
+ mkdir -p data/manifests
+ lhotse prepare musan $dl_dir/musan data/manifests
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Preprocess GigaSpeech manifest"
+ if [ ! -f data/fbank/.preprocess_complete ]; then
+ python3 ./local/preprocess_gigaspeech.py
+ touch data/fbank/.preprocess_complete
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
+ python3 ./local/compute_fbank_gigaspeech_dev_test.py
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Split XL subset into pieces (may take 30 minutes)"
+ split_dir=data/fbank/XL_split
+ if [ ! -f $split_dir/.split_completed ]; then
+ lhotse split-lazy ./data/fbank/cuts_XL_raw.jsonl.gz $split_dir $num_per_split
+ touch $split_dir/.split_completed
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Compute features for XL"
+ num_splits=$(find data/fbank/XL_split -name "cuts_XL_raw.*.jsonl.gz" | wc -l)
+ python3 ./local/compute_fbank_gigaspeech_splits.py \
+ --num-workers 20 \
+ --batch-duration 600 \
+ --num-splits $num_splits
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Combine features for XL (may take 3 hours)"
+ if [ ! -f data/fbank/cuts_XL.jsonl.gz ]; then
+ pieces=$(find data/fbank/XL_split -name "cuts_XL.*.jsonl.gz")
+ lhotse combine $pieces data/fbank/cuts_XL.jsonl.gz
+ fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ log "Stage 8: Compute fbank for musan"
+ mkdir -p data/fbank
+ ./local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+ log "Stage 9: Prepare phone based lang"
+ lang_dir=data/lang_phone
+ mkdir -p $lang_dir
+
+ (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+ cat - $dl_dir/lm/lexicon.txt |
+ sort | uniq > $lang_dir/lexicon.txt
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang.py --lang-dir $lang_dir
+ fi
+
+ if [ ! -f $lang_dir/transcript_words.txt ]; then
+ gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \
+ | jq '.text' \
+ | sed 's/"//g' \
+ > $lang_dir/transcript_words.txt
+
+ # Delete utterances with garbage meta tags
+ garbage_utterance_tags="<SIL> <MUSIC> <NOISE> <OTHER>"
+ for tag in $garbage_utterance_tags; do
+ sed -i "/${tag}/d" $lang_dir/transcript_words.txt
+ done
+
+ # Delete punctuations in utterances
+ punctuation_tags="<COMMA> <EXCLAMATIONPOINT> <PERIOD> <QUESTIONMARK>"
+ for tag in $punctuation_tags; do
+ sed -i "s/${tag}//g" $lang_dir/transcript_words.txt
+ done
+
+ # Ensure space only appears once
+ sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
+ sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
+ fi
+
+ cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
+ | sort -u | sed '/^$/d' > $lang_dir/words.txt
+ (echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
+ cat - $lang_dir/words.txt | sort | uniq | awk '
+ BEGIN {
+ print "<eps> 0";
+ }
+ {
+ if ($1 == "<s>") {
+ print "<s> is in the vocabulary!" | "cat 1>&2"
+ exit 1;
+ }
+ if ($1 == "</s>") {
+ print "</s> is in the vocabulary!" | "cat 1>&2"
+ exit 1;
+ }
+ printf("%s %d\n", $1, NR);
+ }
+ END {
+ printf("#0 %d\n", NR+1);
+ printf("<s> %d\n", NR+2);
+ printf("</s> %d\n", NR+3);
+ }' > $lang_dir/words || exit 1;
+ mv $lang_dir/words $lang_dir/words.txt
+fi
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+ log "Stage 10: Prepare BPE based lang"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ mkdir -p $lang_dir
+ # We reuse words.txt from phone based lexicon
+ # so that the two can share G.pt later.
+ cp data/lang_phone/{words.txt,transcript_words.txt} $lang_dir
+
+ if [ ! -f $lang_dir/bpe.model ]; then
+ ./local/train_bpe_model.py \
+ --lang-dir $lang_dir \
+ --vocab-size $vocab_size \
+ --transcript $lang_dir/transcript_words.txt
+ fi
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+ fi
+ done
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+ log "Stage 11: Prepare bigram P"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+
+ if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+ ./local/convert_transcript_words_to_tokens.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --transcript $lang_dir/transcript_words.txt \
+ --oov "<UNK>" \
+ > $lang_dir/transcript_tokens.txt
+ fi
+
+ if [ ! -f $lang_dir/P.arpa ]; then
+ ./shared/make_kn_lm.py \
+ -ngram-order 2 \
+ -text $lang_dir/transcript_tokens.txt \
+ -lm $lang_dir/P.arpa
+ fi
+
+ if [ ! -f $lang_dir/P.fst.txt ]; then
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_dir/tokens.txt" \
+ --disambig-symbol='#0' \
+ --max-order=2 \
+ $lang_dir/P.arpa > $lang_dir/P.fst.txt
+ fi
+ done
+fi
+
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+ log "Stage 12: Prepare G"
+ # We assume you have installed kaldilm. If not, please install
+ # it using: pip install kaldilm
+
+ mkdir -p data/lm
+
+ if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+ # It is used in building HLG
+ python3 -m kaldilm \
+ --read-symbol-table="data/lang_phone/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=3 \
+ $dl_dir/lm/3gram_pruned_1e7.arpa > data/lm/G_3_gram.fst.txt
+ fi
+
+ if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+ # It is used for LM rescoring
+ python3 -m kaldilm \
+ --read-symbol-table="data/lang_phone/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=4 \
+ $dl_dir/lm/4gram.arpa > data/lm/G_4_gram.fst.txt
+ fi
+fi
+
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+ log "Stage 13: Compile HLG"
+ ./local/compile_hlg.py --lang-dir data/lang_phone
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ ./local/compile_hlg.py --lang-dir $lang_dir
+ done
+fi
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/__init__.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
new file mode 100644
index 000000000..c87686e1e
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -0,0 +1,419 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SingleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class GigaSpeechAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
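+
+ A rough usage sketch (the argument values below are illustrative and it
+ is assumed that the manifests under --manifest-dir already exist)::
+
+ parser = argparse.ArgumentParser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args(["--max-duration", "200"])
+ datamodule = GigaSpeechAsrDataModule(args)
+ train_dl = datamodule.train_dataloaders(datamodule.train_cuts())
+ valid_dl = datamodule.valid_dataloaders(datamodule.dev_cuts())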
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler "
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it "
+ "with training dataset. ",
+ )
+
+ # GigaSpeech specific arguments
+ group.add_argument(
+ "--subset",
+ type=str,
+ default="XL",
+ help="Select the GigaSpeech subset (XS|S|M|L|XL)",
+ )
+ group.add_argument(
+ "--small-dev",
+ type=str2bool,
+ default=False,
+ help="Should we use only 1000 utterances for dev "
+ "(speeds up training)",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(
+ self.args.manifest_dir / "musan_cuts.jsonl.gz"
+ )
+ transforms.append(
+ CutMix(
+ cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
+ )
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(
+ f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+ )
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # The default value of num_frame_masks differs across Lhotse versions.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=True,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SingleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ Fbank(FbankConfig(num_mel_bins=80))
+ ),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info(f"About to get train_{self.args.subset} cuts")
+ path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
+ cuts_train = CutSet.from_jsonl_lazy(path)
+ return cuts_train
+
+ @lru_cache()
+ def dev_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ cuts_valid = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_DEV.jsonl.gz"
+ )
+ if self.args.small_dev:
+ return cuts_valid.subset(first=1000)
+ else:
+ return cuts_valid
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/conformer.py
new file mode 120000
index 000000000..a65957180
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
new file mode 100755
index 000000000..ce5116336
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -0,0 +1,577 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search
+./pruned_transducer_stateless2/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from gigaspeech_scoring import asr_text_post_processing
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=29,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 0.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=8,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ return parser
+
+
+def post_processing(
+ results: List[Tuple[List[str], List[str]]],
+) -> List[Tuple[List[str], List[str]]]:
+ new_results = []
+ for ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. It can be either a `k2.trivial_graph` or an HLG.
+ It is used only when --decoding-method is fast_beam_search.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ x=feature, x_lens=feature_lens
+ )
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif (
+ params.decoding_method == "greedy_search"
+ and params.max_sym_per_frame == 1
+ ):
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ decoding_graph:
+ The decoding graph. It can be either a `k2.trivial_graph` or an HLG.
+ It is used only when --decoding-method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for hyp_words, ref_text in zip(hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = post_processing(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir
+ / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
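+ # e.g. with --epoch 29 --avg 8, epochs 22, 23, ..., 29 are averaged.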
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ gigaspeech = GigaSpeechAsrDataModule(args)
+
+ dev_cuts = gigaspeech.dev_cuts()
+ test_cuts = gigaspeech.test_cuts()
+
+ dev_dl = gigaspeech.test_dataloaders(dev_cuts)
+ test_dl = gigaspeech.test_dataloaders(test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dls = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decoder.py
new file mode 120000
index 000000000..722e1c894
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/encoder_interface.py
new file mode 120000
index 000000000..f58253127
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
new file mode 100755
index 000000000..cff9c7377
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless2/export.py \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch 20 \
+ --avg 10
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `pruned_transducer_stateless2/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/gigaspeech/ASR
+ ./pruned_transducer_stateless2/decode.py \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 100 \
+ --bpe-model data/lang_bpe_500/bpe.model
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for averaging.
+ Note: Epoch counts from 0.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ return parser
+
+
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ model.to(device)
+
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
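+ # e.g. with --epoch 28 --avg 15, epochs 14, 15, ..., 28 are averaged.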
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit:
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+ # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torch.jit.script")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
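+ # i.e. a dict with a single "model" key holding the state_dict; e.g.
+ # load_checkpoint(f"{exp_dir}/pretrained.pt", model) can restore it.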
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = (
+ "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ )
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py
new file mode 120000
index 000000000..a6a4d12b1
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py
@@ -0,0 +1 @@
+../conformer_ctc/gigaspeech_scoring.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/joiner.py
new file mode 120000
index 000000000..9052f3cbb
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/model.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/model.py
new file mode 120000
index 000000000..a99e74334
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/model.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/optim.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/optim.py
new file mode 120000
index 000000000..0a2f285aa
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling.py
new file mode 120000
index 000000000..c10cdfe12
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
new file mode 100755
index 000000000..83ae25561
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py
@@ -0,0 +1,974 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang
+# Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless2/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --full-libri 1 \
+ --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless2/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --full-libri 1 \
+ --max-duration 550
+
+"""
+
+
+import argparse
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[
+ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--initial-lr",
+ type=float,
+ default=0.003,
+ help="The initial learning rate. This value should not need to be changed.",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate decreases.
+ We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)"
+ "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=8000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=20,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used to write statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print the training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - model_warm_step: Number of batches used to warm up the model. It is
+ an argument given to the model, not to the learning rate scheduler.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 500,
+ "reset_interval": 2000,
+ "valid_interval": 20000,
+ # parameters for conformer
+ "feature_dim": 80,
+ "subsampling_factor": 4,
+ "encoder_dim": 512,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "num_encoder_layers": 12,
+ # parameters for decoder
+ "decoder_dim": 512,
+ # parameters for joiner
+ "joiner_dim": 512,
+ # parameters for Noam
+ "model_warm_step": 20000, # arg given to model, not for lrate
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Conformer and Transformer
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ subsampling_factor=params.subsampling_factor,
+ d_model=params.encoder_dim,
+ nhead=params.nhead,
+ dim_feedforward=params.dim_feedforward,
+ num_encoder_layers=params.num_encoder_layers,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=params.encoder_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 0:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+ warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute the transducer loss (simple + pruned) given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ warmup=warmup,
+ )
+ # after the main warmup step, we keep pruned_loss_scale small
+ # for the same amount of time (model_warm_step), to avoid
+ # overwhelming the simple_loss and causing it to diverge,
+ # in case it had not fully learned the alignment yet.
+ pruned_loss_scale = (
+ 0.0
+ if warmup < 1.0
+ else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+ )
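+ # i.e. the scale is 0.0 while warmup < 1.0, 0.1 while 1.0 < warmup < 2.0,
+ # and 1.0 otherwise.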
+ loss = (
+ params.simple_loss_scale * simple_loss
+ + pruned_loss_scale * pruned_loss
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (
+ (feature_lens // params.subsampling_factor).sum().item()
+ )
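+ # `frames` approximates the number of encoder output frames; it is used
+ # later as the normalizer when reporting per-frame losses.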
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler; we call its step_batch() method after every batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ warmup=(params.batch_idx_train / params.model_warm_step),
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
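+ # i.e. an exponential moving average whose effective window is roughly
+ # `reset_interval` batches.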
+
+ # NOTE: We use reduction="sum" when computing the loss, i.e. the loss is
+ # summed over the utterances in the batch and is not normalized here.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+
+ if params.print_diagnostics and batch_idx == 30:
+ return
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}"
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank])
+ model.device = device
+
+ optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
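+ # Eden decays the learning rate as a function of both the number of
+ # batches seen (controlled by --lr-batches) and the number of epochs
+ # (controlled by --lr-epochs).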
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ diagnostic = diagnostics.attach_diagnostics(model)
+
+ gigaspeech = GigaSpeechAsrDataModule(args)
+
+ train_cuts = gigaspeech.train_cuts()
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We load the sampler's state dict only when resuming from a checkpoint
+ # saved in the middle of an epoch.
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = gigaspeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = gigaspeech.dev_cuts()
+ valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs):
+ scheduler.step_epoch(epoch)
+ fix_random_seed(params.seed + epoch)
+ train_dl.sampler.set_epoch(epoch)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+ model: nn.Module,
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 0 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
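+ # `batches` maps each criterion to the set of cuts that maximizes it;
+ # `crit_values` holds the corresponding criterion values.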
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+ # (i.e. are not remembered by the decaying-average in adam), because
+ # we want to avoid these params being subject to shrinkage in adam.
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ warmup=0.0,
+ )
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/ASR/shared b/egs/gigaspeech/ASR/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/gigaspeech/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/librispeech/ASR/.gitignore b/egs/librispeech/ASR/.gitignore
new file mode 100644
index 000000000..5592679cc
--- /dev/null
+++ b/egs/librispeech/ASR/.gitignore
@@ -0,0 +1 @@
+log-*
diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index c8ee98d7d..cbdee53e6 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -1,19 +1,30 @@
-
# Introduction
-Please refer to
-for how to run models in this recipe.
+Please refer to for how to run models in this recipe.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers
There are various folders containing the name `transducer` in this folder.
The following table lists the differences among them.
-| | Encoder | Decoder |
-|------------------------|-----------|--------------------|
-| `transducer` | Conformer | LSTM |
-| `transducer_stateless` | Conformer | Embedding + Conv1d |
-| `transducer_lstm ` | LSTM | LSTM |
+| | Encoder | Decoder | Comment |
+|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
+| `transducer` | Conformer | LSTM | |
+| `transducer_stateless` | Conformer | Embedding + Conv1d | Using optimized_transducer for computing RNN-T loss |
+| `transducer_stateless2` | Conformer | Embedding + Conv1d | Using torchaudio for computing RNN-T loss |
+| `transducer_lstm` | LSTM | LSTM | |
+| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
+| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
+| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
+| `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data |
+| `pruned_transducer_stateless4` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless2 + save averaged models periodically during training |
+| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner|
+| `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert|
+| `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR|
+| `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model |
+| `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
diff --git a/egs/librispeech/ASR/RESULTS-100hours.md b/egs/librispeech/ASR/RESULTS-100hours.md
new file mode 100644
index 000000000..3a064e69d
--- /dev/null
+++ b/egs/librispeech/ASR/RESULTS-100hours.md
@@ -0,0 +1,102 @@
+# Results for train-clean-100
+
+This page shows the WERs for test-clean/test-other using only the
+train-clean-100 subset as training data.
+
+## Distillation with hubert
+### 2022-05-27
+Related models/log/tensorboard:
+https://huggingface.co/GuoLiyong/stateless6_baseline_vs_disstillation
+
+The following results are obtained by running ./distillation_with_hubert.sh
+
+The only difference is in pruned_transducer_stateless6/train.py.
+
+For the baseline: set enable_distillation=False
+
+For distillation: set enable_distillation=True (the default)
+
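+A minimal sketch of the switch described above (hypothetical placement; the
+exact location of the flag inside pruned_transducer_stateless6/train.py may
+differ):
+
+```python
+# pruned_transducer_stateless6/train.py (sketch)
+enable_distillation = True  # set to False to reproduce the baseline
+```
+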
+Decoding method is modified beam search.
+| | test-clean | test-other | comment |
+|-------------------------------------|------------|------------|------------------------------------------|
+| baseline no vq distillation | 7.09 | 18.88 | --epoch 20, --avg 10, --max-duration 200 |
+| baseline no vq distillation | 6.83 | 18.19 | --epoch 30, --avg 10, --max-duration 200 |
+| baseline no vq distillation | 6.73 | 17.79 | --epoch 40, --avg 10, --max-duration 200 |
+| baseline no vq distillation | 6.75 | 17.68 | --epoch 50, --avg 10, --max-duration 200 |
+| distillation with hubert | 5.82 | 15.98 | --epoch 20, --avg 10, --max-duration 200 |
+| distillation with hubert | 5.52 | 15.15 | --epoch 30, --avg 10, --max-duration 200 |
+| distillation with hubert | 5.45 | 14.94 | --epoch 40, --avg 10, --max-duration 200 |
+| distillation with hubert | 5.50 | 14.77 | --epoch 50, --avg 10, --max-duration 200 |
+
+## Conformer encoder + embedding decoder
+
+### 2022-02-21
+
+Using commit `2332ba312d7ce72f08c7bac1e3312f7e3dd722dc`.
+
+| | test-clean | test-other | comment |
+|-------------------------------------|------------|------------|------------------------------------------|
+| greedy search (max sym per frame 1) | 6.34 | 16.7 | --epoch 57, --avg 17, --max-duration 100 |
+| greedy search (max sym per frame 2) | 6.34 | 16.7 | --epoch 57, --avg 17, --max-duration 100 |
+| greedy search (max sym per frame 3) | 6.34 | 16.7 | --epoch 57, --avg 17, --max-duration 100 |
+| modified beam search (beam size 4) | 6.31 | 16.3 | --epoch 57, --avg 17, --max-duration 100 |
+
+
+The training command for reproducing is given below:
+
+```bash
+cd egs/librispeech/ASR/
+./prepare.sh
+./prepare_giga_speech.sh
+
+export CUDA_VISIBLE_DEVICES="0,1"
+
+./transducer_stateless_multi_datasets/train.py \
+ --world-size 2 \
+ --num-epochs 60 \
+ --start-epoch 0 \
+ --exp-dir transducer_stateless_multi_datasets/exp-100-2 \
+ --full-libri 0 \
+ --max-duration 300 \
+ --lr-factor 1 \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --modified-transducer-prob 0.25 \
+ --giga-prob 0.2
+```
+
+The decoding command is given below:
+
+```bash
+for epoch in 57; do
+ for avg in 17; do
+ for sym in 1 2 3; do
+ ./transducer_stateless_multi_datasets/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir transducer_stateless_multi_datasets/exp-100-2 \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --max-duration 100 \
+ --context-size 2 \
+ --max-sym-per-frame $sym
+ done
+ done
+done
+
+epoch=57
+avg=17
+./transducer_stateless_multi_datasets/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir transducer_stateless_multi_datasets/exp-100-2 \
+ --bpe-model ./data/lang_bpe_500/bpe.model \
+ --max-duration 100 \
+ --context-size 2 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+```
+
+The tensorboard log is available at
+
+
+A pre-trained model and decoding logs can be found at
+
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 1476c0528..b10ae98e6 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,63 +1,1957 @@
## Results
-### LibriSpeech BPE training results (Transducer)
+### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T 2)
-#### Conformer encoder + embedding decoder
+[conv_emformer_transducer_stateless2](./conv_emformer_transducer_stateless2)
-Using commit `14c93add507982306f5a478cd144e0e32e0f970d`.
+It implements [Emformer](https://arxiv.org/abs/2010.10759) augmented with a convolution module and a simplified memory bank for streaming ASR.
+It is modified from [torchaudio](https://github.com/pytorch/audio).
+
+See for more details.
+
+#### With lower latency setup, training on full librispeech
+
+In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively.
+
+The WERs are:
+
+| | test-clean | test-other | comment | decoding mode |
+|-------------------------------------|------------|------------|----------------------|----------------------|
+| greedy search (max sym per frame 1) | 3.5 | 9.09 | --epoch 30 --avg 10 | simulated streaming |
+| greedy search (max sym per frame 1) | 3.57 | 9.1 | --epoch 30 --avg 10 | streaming |
+| fast beam search | 3.5 | 8.91 | --epoch 30 --avg 10 | simulated streaming |
+| fast beam search | 3.54 | 8.91 | --epoch 30 --avg 10 | streaming |
+| modified beam search | 3.43 | 8.86 | --epoch 30 --avg 10 | simulated streaming |
+| modified beam search | 3.48 | 8.88 | --epoch 30 --avg 10 | streaming |
+
+The training command is:
+
+```bash
+./conv_emformer_transducer_stateless2/train.py \
+ --world-size 6 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --full-libri 1 \
+ --max-duration 280 \
+ --master-port 12321 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32
+```
+
+The tensorboard log can be found at
+
+
+The simulated streaming decoding command using greedy search is:
+```bash
+./conv_emformer_transducer_stateless2/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --max-duration 300 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method greedy_search \
+ --use-averaged-model True
+```
+
+The simulated streaming decoding command using fast beam search is:
+```bash
+./conv_emformer_transducer_stateless2/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --max-duration 300 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method fast_beam_search \
+ --use-averaged-model True \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+The simulated streaming decoding command using modified beam search is:
+```bash
+./conv_emformer_transducer_stateless2/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --max-duration 300 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method modified_beam_search \
+ --use-averaged-model True \
+ --beam-size 4
+```
+
+The streaming decoding command using greedy search is:
+```bash
+./conv_emformer_transducer_stateless2/streaming_decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --num-decode-streams 2000 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method greedy_search \
+ --use-averaged-model True
+```
+
+The streaming decoding command using fast beam search is:
+```bash
+./conv_emformer_transducer_stateless2/streaming_decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --num-decode-streams 2000 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method fast_beam_search \
+ --use-averaged-model True \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+The streaming decoding command using modified beam search is:
+```bash
+./conv_emformer_transducer_stateless2/streaming_decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --num-decode-streams 2000 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method modified_beam_search \
+ --use-averaged-model True \
+ --beam-size 4
+```
+
+Pretrained models, training logs, decoding logs, and decoding results
+are available at
+
+
+#### With higher latency setup, training on full librispeech
+
+In this model, the lengths of chunk and right context are 64 frames (i.e., 0.64s) and 16 frames (i.e., 0.16s), respectively.
+
+The WERs are:
+
+| | test-clean | test-other | comment | decoding mode |
+|-------------------------------------|------------|------------|----------------------|----------------------|
+| greedy search (max sym per frame 1) | 3.3 | 8.71 | --epoch 30 --avg 10 | simulated streaming |
+| greedy search (max sym per frame 1) | 3.35 | 8.65 | --epoch 30 --avg 10 | streaming |
+| fast beam search | 3.27 | 8.58 | --epoch 30 --avg 10 | simulated streaming |
+| fast beam search | 3.31 | 8.48 | --epoch 30 --avg 10 | streaming |
+| modified beam search | 3.26 | 8.56 | --epoch 30 --avg 10 | simulated streaming |
+| modified beam search | 3.29 | 8.47 | --epoch 30 --avg 10 | streaming |
+
+The training command is:
+
+```bash
+./conv_emformer_transducer_stateless2/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --full-libri 1 \
+ --max-duration 280 \
+ --master-port 12321 \
+ --num-encoder-layers 12 \
+ --chunk-length 64 \
+ --cnn-module-kernel 31 \
+ --left-context-length 64 \
+ --right-context-length 16 \
+ --memory-size 32
+```
+
+The tensorboard log can be found at
+
+
+The simulated streaming decoding command using greedy search is:
+```bash
+./conv_emformer_transducer_stateless2/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --max-duration 300 \
+ --num-encoder-layers 12 \
+ --chunk-length 64 \
+ --cnn-module-kernel 31 \
+ --left-context-length 64 \
+ --right-context-length 16 \
+ --memory-size 32 \
+ --decoding-method greedy_search \
+ --use-averaged-model True
+```
+
+The simulated streaming decoding command using fast beam search is:
+```bash
+./conv_emformer_transducer_stateless2/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --max-duration 300 \
+ --num-encoder-layers 12 \
+ --chunk-length 64 \
+ --cnn-module-kernel 31 \
+ --left-context-length 64 \
+ --right-context-length 16 \
+ --memory-size 32 \
+ --decoding-method fast_beam_search \
+ --use-averaged-model True \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+The simulated streaming decoding command using modified beam search is:
+```bash
+./conv_emformer_transducer_stateless2/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --max-duration 300 \
+ --num-encoder-layers 12 \
+ --chunk-length 64 \
+ --cnn-module-kernel 31 \
+ --left-context-length 64 \
+ --right-context-length 16 \
+ --memory-size 32 \
+ --decoding-method modified_beam_search \
+ --use-averaged-model True \
+ --beam-size 4
+```
+
+The streaming decoding command using greedy search is:
+```bash
+./conv_emformer_transducer_stateless2/streaming_decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --num-decode-streams 2000 \
+ --num-encoder-layers 12 \
+ --chunk-length 64 \
+ --cnn-module-kernel 31 \
+ --left-context-length 64 \
+ --right-context-length 16 \
+ --memory-size 32 \
+ --decoding-method greedy_search \
+ --use-averaged-model True
+```
+
+The streaming decoding command using fast beam search is:
+```bash
+./conv_emformer_transducer_stateless2/streaming_decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --num-decode-streams 2000 \
+ --num-encoder-layers 12 \
+ --chunk-length 64 \
+ --cnn-module-kernel 31 \
+ --left-context-length 64 \
+ --right-context-length 16 \
+ --memory-size 32 \
+ --decoding-method fast_beam_search \
+ --use-averaged-model True \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+The streaming decoding command using modified beam search is:
+```bash
+./conv_emformer_transducer_stateless2/streaming_decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless2/exp \
+ --num-decode-streams 2000 \
+ --num-encoder-layers 12 \
+ --chunk-length 64 \
+ --cnn-module-kernel 31 \
+ --left-context-length 64 \
+ --right-context-length 16 \
+ --memory-size 32 \
+ --decoding-method modified_beam_search \
+ --use-averaged-model True \
+ --beam-size 4
+```
+
+Pretrained models, training logs, decoding logs, and decoding results
+are available at
+
+
+
+### LibriSpeech BPE training results (Pruned Stateless Streaming Conformer RNN-T)
+
+#### [pruned_transducer_stateless](./pruned_transducer_stateless)
+
+See for more details.
+
+##### Training on full librispeech
+The WERs are given below (each entry is formatted as test-clean & test-other):
+
+We trained for only 25 epochs to save time; you can train for more epochs to get better results.
+
+| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
+|----------------------|--------------|----------------|----------------|----------------|----------------|
+| greedy search | 32 | 4.74 & 11.38 | 4.57 & 10.86 | 4.18 & 10.37 | 3.87 & 9.85 |
+| greedy search | 64 | 4.74 & 11.25 | 4.48 & 10.72 | 4.1 & 10.24 | 3.85 & 9.73 |
+| fast beam search | 32 | 4.75 & 11.1 | 4.48 & 10.65 | 4.12 & 10.18 | 3.95 & 9.67 |
+| fast beam search | 64 | 4.7 & 11 | 4.37 & 10.49 | 4.07 & 10.04 | 3.89 & 9.53 |
+| modified beam search | 32 | 4.64 & 10.94 | 4.38 & 10.51 | 4.11 & 10.14 | 3.87 & 9.61 |
+| modified beam search | 64 | 4.59 & 10.81 | 4.29 & 10.39 | 4.02 & 10.02 | 3.84 & 9.43 |
+
+**NOTE:** The WERs in the table above were obtained with the simulated streaming method (i.e., using a masking strategy); see the commands below. We also have a [real streaming decoding](./pruned_transducer_stateless/streaming_decode.py) script, which should produce almost the same results. We tried adding right context in real streaming decoding, but it did not seem to benefit any of the models; the reason might be the mismatch between training and decoding.
+
+The training command is:
+
+```bash
+./pruned_transducer_stateless/train.py \
+ --exp-dir pruned_transducer_stateless/exp \
+ --full-libri 1 \
+ --dynamic-chunk-training 1 \
+ --causal-convolution 1 \
+ --short-chunk-size 20 \
+ --num-left-chunks 4 \
+ --max-duration 300 \
+ --world-size 4 \
+ --start-epoch 0 \
+ --num-epochs 25
+```
+
+You can find the tensorboard log here
+
+The decoding command is:
+```bash
+decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
+
+for chunk in 2 4 8 16; do
+ for left in 32 64; do
+ ./pruned_transducer_stateless/decode.py \
+ --simulate-streaming 1 \
+ --decode-chunk-size ${chunk} \
+ --left-context ${left} \
+ --causal-convolution 1 \
+ --epoch 24 \
+ --avg 10 \
+ --exp-dir ./pruned_transducer_stateless/exp \
+ --max-sym-per-frame 1 \
+ --max-duration 1000 \
+ --decoding-method ${decoding_method}
+ done
+done
+```
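+
+The comment in the first line above lists the other decoding methods. To sweep all
+three in one go, the method variable can itself be looped; a minimal sketch that
+reuses only the options already shown above:
+
+```bash
+for decoding_method in greedy_search fast_beam_search modified_beam_search; do
+  for chunk in 2 4 8 16; do
+    for left in 32 64; do
+      ./pruned_transducer_stateless/decode.py \
+        --simulate-streaming 1 \
+        --decode-chunk-size ${chunk} \
+        --left-context ${left} \
+        --causal-convolution 1 \
+        --epoch 24 \
+        --avg 10 \
+        --exp-dir ./pruned_transducer_stateless/exp \
+        --max-sym-per-frame 1 \
+        --max-duration 1000 \
+        --decoding-method ${decoding_method}
+    done
+  done
+done
+```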
+
+Pre-trained models, training and decoding logs, and decoding results are available at
+
+#### [pruned_transducer_stateless2](./pruned_transducer_stateless2)
+
+See for more details.
+
+##### Training on full librispeech
+The WERs are given below (each entry is formatted as test-clean & test-other):
+
+We trained for only 25 epochs to save time; you can train for more epochs to get better results.
+
+| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
+|----------------------|--------------|----------------|----------------|----------------|----------------|
+| greedy search | 32 | 4.2 & 10.64 | 3.97 & 10.03 | 3.83 & 9.58 | 3.7 & 9.11 |
+| greedy search | 64 | 4.16 & 10.5 | 3.93 & 9.99 | 3.73 & 9.45 | 3.63 & 9.04 |
+| fast beam search | 32 | 4.13 & 10.3 | 3.93 & 9.82 | 3.8 & 9.35 | 3.62 & 8.93 |
+| fast beam search | 64 | 4.13 & 10.22 | 3.89 & 9.68 | 3.73 & 9.27 | 3.52 & 8.82 |
+| modified beam search | 32 | 4.02 & 10.22 | 3.9 & 9.71 | 3.74 & 9.33 | 3.59 & 8.87 |
+| modified beam search | 64 | 4.05 & 10.08 | 3.81 & 9.67 | 3.68 & 9.21 | 3.56 & 8.77 |
+
+**NOTE:** The WERs in the table above were obtained with the simulated streaming method (i.e., using a masking strategy); see the commands below. We also have a [real streaming decoding](./pruned_transducer_stateless2/streaming_decode.py) script, which should produce almost the same results. We tried adding right context in real streaming decoding, but it did not seem to benefit any of the models; the reason might be the mismatch between training and decoding.
+
+The training command is:
+
+```bash
+./pruned_transducer_stateless2/train.py \
+ --exp-dir pruned_transducer_stateless2/exp \
+ --full-libri 1 \
+ --dynamic-chunk-training 1 \
+ --causal-convolution 1 \
+ --short-chunk-size 20 \
+ --num-left-chunks 4 \
+ --max-duration 300 \
+ --world-size 4 \
+ --start-epoch 0 \
+ --num-epochs 25
+```
+
+You can find the tensorboard log here
+
+The decoding command is:
+```bash
+decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
+
+for chunk in 2 4 8 16; do
+ for left in 32 64; do
+ ./pruned_transducer_stateless2/decode.py \
+ --simulate-streaming 1 \
+ --decode-chunk-size ${chunk} \
+ --left-context ${left} \
+ --causal-convolution 1 \
+ --epoch 24 \
+ --avg 10 \
+ --exp-dir ./pruned_transducer_stateless2/exp \
+ --max-sym-per-frame 1 \
+ --max-duration 1000 \
+ --decoding-method ${decoding_method}
+ done
+done
+```
+
+Pre-trained models, training and decoding logs, and decoding results are available at
+
+#### [pruned_transducer_stateless3](./pruned_transducer_stateless3)
+
+See for more details.
+
+##### Training on full librispeech (**Use giga_prob = 0.5**)
+
+The WERs are given below (each entry is formatted as test-clean & test-other):
+
+| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
+|----------------------|--------------|----------------|----------------|----------------|----------------|
+| greedy search | 32 | 3.7 & 9.53 | 3.45 & 8.88 | 3.28 & 8.45 | 3.13 & 7.93 |
+| greedy search | 64 | 3.69 & 9.36 | 3.39 & 8.68 | 3.28 & 8.19 | 3.08 & 7.83 |
+| fast beam search | 32 | 3.71 & 9.18 | 3.36 & 8.65 | 3.23 & 8.23 | 3.17 & 7.78 |
+| fast beam search | 64 | 3.61 & 9.03 | 3.46 & 8.43 | 3.2 & 8.0 | 3.11 & 7.63 |
+| modified beam search | 32 | 3.56 & 9.08 | 3.34 & 8.58 | 3.21 & 8.14 | 3.06 & 7.73 |
+| modified beam search | 64 | 3.55 & 8.86 | 3.29 & 8.34 | 3.16 & 8.01 | 3.05 & 7.57 |
+
+**NOTE:** The WERs in the table above were obtained with the simulated streaming method (i.e., using a masking strategy); see the commands below. We also have a [real streaming decoding](./pruned_transducer_stateless3/streaming_decode.py) script, which should produce almost the same results. We tried adding right context in real streaming decoding, but it did not seem to benefit any of the models; the reason might be the mismatch between training and decoding.
+
+The training command is (note: this model was trained with mixed-precision training):
+
+```bash
+./pruned_transducer_stateless3/train.py \
+ --exp-dir pruned_transducer_stateless3/exp \
+ --full-libri 1 \
+ --dynamic-chunk-training 1 \
+ --causal-convolution 1 \
+ --short-chunk-size 32 \
+ --num-left-chunks 4 \
+ --max-duration 300 \
+ --world-size 4 \
+ --use-fp16 1 \
+ --start-epoch 0 \
+ --num-epochs 37 \
+ --num-workers 2 \
+ --giga-prob 0.5
+```
+
+You can find the tensorboard log here
+
+The decoding command is:
+```bash
+decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
+
+for chunk in 2 4 8 16; do
+ for left in 32 64; do
+ ./pruned_transducer_stateless3/decode.py \
+ --simulate-streaming 1 \
+ --decode-chunk-size ${chunk} \
+ --left-context ${left} \
+ --causal-convolution 1 \
+ --epoch 36 \
+ --avg 8 \
+ --exp-dir ./pruned_transducer_stateless3/exp \
+ --max-sym-per-frame 1 \
+ --max-duration 1000 \
+ --decoding-method ${decoding_method}
+ done
+done
+```
+
+Pre-trained models, training and decoding logs, and decoding results are available at
+
+##### Training on full librispeech (**Use giga_prob = 0.9**)
+
+The WERs are given below (each entry is formatted as test-clean & test-other):
+
+| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
+|----------------------|--------------|----------------|----------------|----------------|----------------|
+| greedy search | 32 | 3.25 & 8.2 | 3.07 & 7.67 | 2.91 & 7.28 | 2.8 & 6.89 |
+| greedy search | 64 | 3.22 & 8.12 | 3.05 & 7.59 | 2.91 & 7.07 | 2.78 & 6.81 |
+| fast beam search | 32 | 3.26 & 8.2 | 3.06 & 7.56 | 2.98 & 7.08 | 2.77 & 6.75 |
+| fast beam search | 64 | 3.24 & 8.09 | 3.06 & 7.43 | 2.88 & 7.03 | 2.73 & 6.68 |
+| modified beam search | 32 | 3.13 & 7.91 | 2.99 & 7.45 | 2.83 & 6.98 | 2.68 & 6.75 |
+| modified beam search | 64 | 3.08 & 7.8 | 2.97 & 7.37 | 2.81 & 6.82 | 2.66 & 6.67 |
+
+**NOTE:** The WERs in the table above were obtained with the simulated streaming method (i.e., using a masking strategy); see the commands below. We also have a [real streaming decoding](./pruned_transducer_stateless3/streaming_decode.py) script, which should produce almost the same results. We tried adding right context in real streaming decoding, but it did not seem to benefit any of the models; the reason might be the mismatch between training and decoding.
+
+The training command is:
+
+```bash
+./pruned_transducer_stateless3/train.py \
+ --exp-dir pruned_transducer_stateless3/exp \
+ --full-libri 1 \
+ --dynamic-chunk-training 1 \
+ --causal-convolution 1 \
+ --short-chunk-size 25 \
+ --num-left-chunks 8 \
+ --max-duration 300 \
+ --world-size 8 \
+ --start-epoch 0 \
+ --num-epochs 26 \
+ --num-workers 2 \
+ --giga-prob 0.9
+```
+
+You can find the tensorboard log here
+
+The decoding command is:
+```bash
+decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
+
+for chunk in 2 4 8 16; do
+ for left in 32 64; do
+ ./pruned_transducer_stateless3/decode.py \
+ --simulate-streaming 1 \
+ --decode-chunk-size ${chunk} \
+ --left-context ${left} \
+ --causal-convolution 1 \
+ --epoch 25 \
+ --avg 12 \
+ --exp-dir ./pruned_transducer_stateless3/exp \
+ --max-sym-per-frame 1 \
+ --max-duration 1000 \
+ --decoding-method ${decoding_method}
+ done
+done
+```
+
+Pre-trained models, training and decoding logs, and decoding results are available at
+
+#### [pruned_transducer_stateless4](./pruned_transducer_stateless4)
+
+See for more details.
+
+##### Training on full librispeech
+The WERs are given below (each entry is formatted as test-clean & test-other):
+
+We trained for only 25 epochs to save time; you can train for more epochs to get better results.
+
+| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
+|----------------------|--------------|----------------|----------------|----------------|----------------|
+| greedy search | 32 | 3.96 & 10.45 | 3.73 & 9.97 | 3.54 & 9.56 | 3.45 & 9.08 |
+| greedy search | 64 | 3.9 & 10.34 | 3.7 & 9.9 | 3.53 & 9.41 | 3.39 & 9.03 |
+| fast beam search | 32 | 3.9 & 10.09 | 3.69 & 9.65 | 3.58 & 9.28 | 3.46 & 8.91 |
+| fast beam search | 64 | 3.82 & 10.03 | 3.67 & 9.56 | 3.51 & 9.18 | 3.43 & 8.78 |
+| modified beam search | 32 | 3.78 & 10.0 | 3.63 & 9.54 | 3.43 & 9.29 | 3.39 & 8.84 |
+| modified beam search | 64 | 3.76 & 9.95 | 3.54 & 9.48 | 3.4 & 9.13 | 3.33 & 8.74 |
+
+**NOTE:** The WERs in the table above were obtained with the simulated streaming method (i.e., using a masking strategy); see the commands below. We also have a [real streaming decoding](./pruned_transducer_stateless4/streaming_decode.py) script, which should produce almost the same results. We tried adding right context in real streaming decoding, but it did not seem to benefit any of the models; the reason might be the mismatch between training and decoding.
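+
+The real streaming decoding script mentioned above is
+[streaming_decode.py](./pruned_transducer_stateless4/streaming_decode.py). Its exact
+options are not reproduced in this file; the sketch below is only an illustration and
+assumes the script shares the chunk/left-context options of `decode.py` and accepts a
+`--num-decode-streams` option like the Emformer recipes. Check the script's `--help`
+output for the authoritative list of options:
+
+```bash
+# Assumed invocation -- verify the flag names against
+#   ./pruned_transducer_stateless4/streaming_decode.py --help
+./pruned_transducer_stateless4/streaming_decode.py \
+  --epoch 25 \
+  --avg 3 \
+  --exp-dir ./pruned_transducer_stateless4/exp \
+  --decode-chunk-size 16 \
+  --left-context 64 \
+  --num-decode-streams 1000 \
+  --decoding-method greedy_search
+```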
+
+The training command is:
+
+```bash
+./pruned_transducer_stateless4/train.py \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --full-libri 1 \
+ --dynamic-chunk-training 1 \
+ --causal-convolution 1 \
+ --short-chunk-size 20 \
+ --num-left-chunks 4 \
+ --max-duration 300 \
+ --world-size 4 \
+ --start-epoch 1 \
+ --num-epochs 25
+```
+
+You can find the tensorboard log here
+
+The decoding command is:
+```bash
+decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
+
+for chunk in 2 4 8 16; do
+ for left in 32 64; do
+ ./pruned_transducer_stateless4/decode.py \
+ --simulate-streaming 1 \
+ --decode-chunk-size ${chunk} \
+ --left-context ${left} \
+ --causal-convolution 1 \
+ --epoch 25 \
+ --avg 3 \
+ --exp-dir ./pruned_transducer_stateless4/exp \
+ --max-sym-per-frame 1 \
+ --max-duration 1000 \
+ --decoding-method ${decoding_method}
+ done
+done
+```
+
+Pre-trained models, training and decoding logs, and decoding results are available at
+
+
+### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T)
+
+[conv_emformer_transducer_stateless](./conv_emformer_transducer_stateless)
+
+It implements [Emformer](https://arxiv.org/abs/2010.10759) augmented with a convolution module for streaming ASR.
+It is modified from [torchaudio](https://github.com/pytorch/audio).
+
+See for more details.
+
+#### Training on full librispeech
+
+In this model, the chunk length and right-context length are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively.
+
+The WERs are:
+
+| | test-clean | test-other | comment | decoding mode |
+|-------------------------------------|------------|------------|----------------------|----------------------|
+| greedy search (max sym per frame 1) | 3.63 | 9.61 | --epoch 30 --avg 10 | simulated streaming |
+| greedy search (max sym per frame 1) | 3.64 | 9.65 | --epoch 30 --avg 10 | streaming |
+| fast beam search | 3.61 | 9.4 | --epoch 30 --avg 10 | simulated streaming |
+| fast beam search | 3.58 | 9.5 | --epoch 30 --avg 10 | streaming |
+| modified beam search | 3.56 | 9.41 | --epoch 30 --avg 10 | simulated streaming |
+| modified beam search | 3.54 | 9.46 | --epoch 30 --avg 10 | streaming |
+
+The training command is:
+
+```bash
+./conv_emformer_transducer_stateless/train.py \
+ --world-size 6 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir conv_emformer_transducer_stateless/exp \
+ --full-libri 1 \
+ --max-duration 300 \
+ --master-port 12321 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32
+```
+
+The tensorboard log can be found at
+
+
+The simulated streaming decoding command using greedy search is:
+```bash
+./conv_emformer_transducer_stateless/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless/exp \
+ --max-duration 300 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method greedy_search \
+ --use-averaged-model True
+```
+
+The simulated streaming decoding command using fast beam search is:
+```bash
+./conv_emformer_transducer_stateless/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless/exp \
+ --max-duration 300 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method fast_beam_search \
+ --use-averaged-model True \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+The simulated streaming decoding command using modified beam search is:
+```bash
+./conv_emformer_transducer_stateless/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless/exp \
+ --max-duration 300 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method modified_beam_search \
+ --use-averaged-model True \
+ --beam-size 4
+```
+
+The streaming decoding command using greedy search is:
+```bash
+./conv_emformer_transducer_stateless/streaming_decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless/exp \
+ --num-decode-streams 2000 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method greedy_search \
+ --use-averaged-model True
+```
+
+The streaming decoding command using fast beam search is:
+```bash
+./conv_emformer_transducer_stateless/streaming_decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless/exp \
+ --num-decode-streams 2000 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method fast_beam_search \
+ --use-averaged-model True \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+The streaming decoding command using modified beam search is:
+```bash
+./conv_emformer_transducer_stateless/streaming_decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir conv_emformer_transducer_stateless/exp \
+ --num-decode-streams 2000 \
+ --num-encoder-layers 12 \
+ --chunk-length 32 \
+ --cnn-module-kernel 31 \
+ --left-context-length 32 \
+ --right-context-length 8 \
+ --memory-size 32 \
+ --decoding-method modified_beam_search \
+ --use-averaged-model True \
+ --beam-size 4
+```
+
+Pretrained models, training logs, decoding logs, and decoding results
+are available at
+
+
+### LibriSpeech BPE training results (Pruned Stateless Emformer RNN-T)
+
+[pruned_stateless_emformer_rnnt2](./pruned_stateless_emformer_rnnt2)
+
+Use .
+
+Use [Emformer](https://arxiv.org/abs/2010.10759) from [torchaudio](https://github.com/pytorch/audio)
+for streaming ASR. The Emformer model is imported from torchaudio without modifications.
+
+You can use to deploy it.
+
+| | test-clean | test-other | comment |
+|-------------------------------------|------------|------------|----------------------------------------|
+| greedy search (max sym per frame 1) | 4.28 | 11.42 | --epoch 39 --avg 6 --max-duration 600 |
+| modified beam search | 4.22 | 11.16 | --epoch 39 --avg 6 --max-duration 600 |
+| fast beam search | 4.29 | 11.26 | --epoch 39 --avg 6 --max-duration 600 |
+
+
+The training commands are:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_stateless_emformer_rnnt2/train.py \
+ --world-size 8 \
+ --num-epochs 40 \
+ --start-epoch 1 \
+ --exp-dir pruned_stateless_emformer_rnnt2/exp-full \
+ --full-libri 1 \
+ --use-fp16 0 \
+ --max-duration 200 \
+ --prune-range 5 \
+ --lm-scale 0.25 \
+ --master-port 12358 \
+ --num-encoder-layers 18 \
+ --left-context-length 128 \
+ --segment-length 8 \
+ --right-context-length 4
+```
+
+The tensorboard log can be found at
+
+
+The decoding commands are:
+```bash
+for m in greedy_search fast_beam_search modified_beam_search; do
+ for epoch in 39; do
+ for avg in 6; do
+ ./pruned_stateless_emformer_rnnt2/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --use-averaged-model 1 \
+ --exp-dir pruned_stateless_emformer_rnnt2/exp-full \
+ --max-duration 50 \
+ --decoding-method $m \
+ --num-encoder-layers 18 \
+ --left-context-length 128 \
+ --segment-length 8 \
+ --right-context-length 4
+ done
+ done
+done
+```
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+
+### LibriSpeech BPE training results (Pruned Stateless Transducer 5)
+
+[pruned_transducer_stateless5](./pruned_transducer_stateless5)
+
+Same as `Pruned Stateless Transducer 2` but with more layers.
+
+See
+
+Note that models in `pruned_transducer_stateless` and `pruned_transducer_stateless2`
+have about 80 M parameters.
+
+The notations `large` and `medium` below are from the [Conformer](https://arxiv.org/pdf/2005.08100.pdf)
+paper, where the large model has about 118 M parameters and the medium model
+has 30.8 M parameters.
+
+#### Large
+
+Number of model parameters: 118129516 (i.e., 118.13 M).
+
+| | test-clean | test-other | comment |
+|-------------------------------------|------------|------------|----------------------------------------|
+| greedy search (max sym per frame 1) | 2.43 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
+| modified beam search | 2.43 | 5.69 | --epoch 30 --avg 10 --max-duration 600 |
+| fast beam search | 2.43 | 5.67 | --epoch 30 --avg 10 --max-duration 600 |
+
+The training commands are:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless5/train.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --full-libri 1 \
+ --exp-dir pruned_transducer_stateless5/exp-L \
+ --max-duration 300 \
+ --use-fp16 0 \
+ --num-encoder-layers 18 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --encoder-dim 512 \
+ --decoder-dim 512 \
+ --joiner-dim 512
+```
+
+The tensorboard log can be found at
+
+
+The decoding commands are:
+
+```bash
+for method in greedy_search modified_beam_search fast_beam_search; do
+ ./pruned_transducer_stateless5/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir ./pruned_transducer_stateless5/exp-L \
+ --max-duration 600 \
+ --decoding-method $method \
+ --max-sym-per-frame 1 \
+ --num-encoder-layers 18 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --encoder-dim 512 \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --use-averaged-model True
+done
+```
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+
+#### Medium
+
+Number of model parameters: 30896748 (i.e., 30.9 M).
+
+| | test-clean | test-other | comment |
+|-------------------------------------|------------|------------|-----------------------------------------|
+| greedy search (max sym per frame 1) | 2.87 | 6.92 | --epoch 30 --avg 10 --max-duration 600 |
+| modified beam search | 2.83 | 6.75 | --epoch 30 --avg 10 --max-duration 600 |
+| fast beam search | 2.81 | 6.76 | --epoch 30 --avg 10 --max-duration 600 |
+
+The training commands are:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless5/train.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --full-libri 1 \
+ --exp-dir pruned_transducer_stateless5/exp-M \
+ --max-duration 300 \
+ --use-fp16 0 \
+ --num-encoder-layers 18 \
+ --dim-feedforward 1024 \
+ --nhead 4 \
+ --encoder-dim 256 \
+ --decoder-dim 512 \
+ --joiner-dim 512
+```
+
+The tensorboard log can be found at
+
+
+The decoding commands are:
+
+```bash
+for method in greedy_search modified_beam_search fast_beam_search; do
+ ./pruned_transducer_stateless5/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir ./pruned_transducer_stateless5/exp-M \
+ --max-duration 600 \
+ --decoding-method $method \
+ --max-sym-per-frame 1 \
+ --num-encoder-layers 18 \
+ --dim-feedforward 1024 \
+ --nhead 4 \
+ --encoder-dim 256 \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --use-averaged-model True
+done
+```
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+
+#### Baseline-2
+
+It has 88.98 M parameters. Compared to the model in pruned_transducer_stateless2, it has more
+layers (24 vs. 12) but is narrower (feedforward dim 1536 and encoder dim 384 vs. feedforward dim 2048 and encoder dim 512).
+
+| | test-clean | test-other | comment |
+|-------------------------------------|------------|------------|-----------------------------------------|
+| greedy search (max sym per frame 1) | 2.54 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
+| modified beam search | 2.47 | 5.71 | --epoch 30 --avg 10 --max-duration 600 |
+| fast beam search | 2.5 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
+
+The training commands are:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless5/train.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --full-libri 1 \
+ --exp-dir pruned_transducer_stateless5/exp-B \
+ --max-duration 300 \
+ --use-fp16 0 \
+ --num-encoder-layers 24 \
+ --dim-feedforward 1536 \
+ --nhead 8 \
+ --encoder-dim 384 \
+ --decoder-dim 512 \
+ --joiner-dim 512
+```
+
+The tensorboard log can be found at
+
+
+The decoding commands are:
+
+```bash
+for method in greedy_search modified_beam_search fast_beam_search; do
+ ./pruned_transducer_stateless5/decode.py \
+ --epoch 30 \
+ --avg 10 \
+ --exp-dir ./pruned_transducer_stateless5/exp-B \
+ --max-duration 600 \
+ --decoding-method $method \
+ --max-sym-per-frame 1 \
+ --num-encoder-layers 24 \
+ --dim-feedforward 1536 \
+ --nhead 8 \
+ --encoder-dim 384 \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --use-averaged-model True
+done
+```
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+
+### LibriSpeech BPE training results (Pruned Stateless Transducer 4)
+
+[pruned_transducer_stateless4](./pruned_transducer_stateless4)
+
+This version saves an averaged model during training and decodes with the averaged model.
+
+See for details about the idea of model averaging.
+
+#### Training on full librispeech
+
+See
+
+Using commit `ec0b0e92297cc03fdb09f48cd235e84d2c04156b`.
+
+The WERs are:
+
+| | test-clean | test-other | comment |
+|-------------------------------------|------------|------------|-------------------------------------------------------------------------------|
+| greedy search (max sym per frame 1) | 2.75 | 6.74 | --epoch 30 --avg 6 --use-averaged-model False |
+| greedy search (max sym per frame 1) | 2.69 | 6.64 | --epoch 30 --avg 6 --use-averaged-model True |
+| fast beam search | 2.72 | 6.67 | --epoch 30 --avg 6 --use-averaged-model False |
+| fast beam search | 2.66 | 6.6 | --epoch 30 --avg 6 --use-averaged-model True |
+| modified beam search | 2.67 | 6.68 | --epoch 30 --avg 6 --use-averaged-model False |
+| modified beam search | 2.62 | 6.57 | --epoch 30 --avg 6 --use-averaged-model True |
+
+The training command is:
+
+```bash
+./pruned_transducer_stateless4/train.py \
+ --world-size 6 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --full-libri 1 \
+ --max-duration 300 \
+ --save-every-n 8000 \
+ --keep-last-k 20 \
+ --average-period 100
+```
+
+The tensorboard log can be found at
+
+
+The decoding command using greedy search is:
+```bash
+./pruned_transducer_stateless4/decode.py \
+ --epoch 30 \
+ --avg 6 \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --max-duration 300 \
+ --decoding-method greedy_search \
+ --use-averaged-model True
+```
+
+The decoding command using fast beam search is:
+```bash
+./pruned_transducer_stateless4/decode.py \
+ --epoch 30 \
+ --avg 6 \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --max-duration 300 \
+ --decoding-method fast_beam_search \
+ --use-averaged-model True \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
+```
+
+The decoding command using modified beam search is:
+```bash
+./pruned_transducer_stateless4/decode.py \
+ --epoch 30 \
+ --avg 6 \
+ --exp-dir pruned_transducer_stateless4/exp \
+ --max-duration 300 \
+ --decoding-method modified_beam_search \
+ --use-averaged-model True \
+ --beam-size 4
+```
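+
+Each pair of rows in the WER table compares `--use-averaged-model False` and `True`.
+That toggle can be scripted directly; a minimal sketch for greedy search, reusing only
+the options already shown above:
+
+```bash
+# Reproduces the two greedy-search rows of the WER table; the same toggle applies
+# to the fast beam search and modified beam search commands above.
+for use_avg in True False; do
+  ./pruned_transducer_stateless4/decode.py \
+    --epoch 30 \
+    --avg 6 \
+    --exp-dir pruned_transducer_stateless4/exp \
+    --max-duration 300 \
+    --decoding-method greedy_search \
+    --use-averaged-model $use_avg
+done
+```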
+
+Pretrained models, training logs, decoding logs, and decoding results
+are available at
+
+
+#### Training on train-clean-100
+
+See