diff --git a/.github/scripts/.gitignore b/.github/scripts/.gitignore
new file mode 100644
index 000000000..672e477d8
--- /dev/null
+++ b/.github/scripts/.gitignore
@@ -0,0 +1 @@
+piper_phonemize.html
diff --git a/.github/scripts/audioset/AT/run.sh b/.github/scripts/audioset/AT/run.sh
new file mode 100755
index 000000000..87856b64d
--- /dev/null
+++ b/.github/scripts/audioset/AT/run.sh
@@ -0,0 +1,94 @@
+#!/usr/bin/env bash
+
+set -ex
+
+python3 -m pip install onnxoptimizer onnxsim
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/audioset/AT
+
+function test_pretrained() {
+ repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
+ repo=$(basename $repo_url)
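+  # Clone without downloading LFS blobs, then fetch only the checkpoint we need.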
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ pushd $repo/exp
+ git lfs pull --include pretrained.pt
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh
+ popd
+
+ log "test pretrained.pt"
+
+ python3 zipformer/pretrained.py \
+ --checkpoint $repo/exp/pretrained.pt \
+ --label-dict $repo/data/class_labels_indices.csv \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav \
+ $repo/test_wavs/3.wav \
+ $repo/test_wavs/4.wav
+
+ log "test jit export"
+ ls -lh $repo/exp/
+ python3 zipformer/export.py \
+ --exp-dir $repo/exp \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --jit 1
+ ls -lh $repo/exp/
+
+ log "test jit models"
+ python3 zipformer/jit_pretrained.py \
+ --nn-model-filename $repo/exp/jit_script.pt \
+ --label-dict $repo/data/class_labels_indices.csv \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav \
+ $repo/test_wavs/3.wav \
+ $repo/test_wavs/4.wav
+
+ log "test onnx export"
+ ls -lh $repo/exp/
+ python3 zipformer/export-onnx.py \
+ --exp-dir $repo/exp \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0
+
+ ls -lh $repo/exp/
+
+ pushd $repo/exp/
+ mv model-epoch-99-avg-1.onnx model.onnx
+ mv model-epoch-99-avg-1.int8.onnx model.int8.onnx
+ popd
+
+ ls -lh $repo/exp/
+
+ log "test onnx models"
+ for m in model.onnx model.int8.onnx; do
+ log "$m"
+ python3 zipformer/onnx_pretrained.py \
+      --model-filename $repo/exp/$m \
+ --label-dict $repo/data/class_labels_indices.csv \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav \
+ $repo/test_wavs/3.wav \
+ $repo/test_wavs/4.wav
+ done
+
+ log "prepare data for uploading to huggingface"
+ dst=/icefall/model-onnx
+ mkdir -p $dst
+ cp -v $repo/exp/*.onnx $dst/
+ cp -v $repo/data/* $dst/
+ cp -av $repo/test_wavs $dst
+
+ ls -lh $dst
+ ls -lh $dst/test_wavs
+}
+
+test_pretrained
diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile
index f6a088af1..15f49f826 100644
--- a/.github/scripts/docker/Dockerfile
+++ b/.github/scripts/docker/Dockerfile
@@ -11,6 +11,7 @@ ARG _KALDIFEAT_VERSION="${KALDIFEAT_VERSION}+cpu.torch${TORCH_VERSION}"
RUN apt-get update -y && \
apt-get install -qq -y \
+ cmake \
ffmpeg \
git \
git-lfs \
@@ -35,7 +36,9 @@ RUN pip install --no-cache-dir \
\
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \
+ cython \
dill \
+ espnet_tts_frontend \
graphviz \
kaldi-decoder \
kaldi_native_io \
@@ -44,10 +47,15 @@ RUN pip install --no-cache-dir \
kaldilm \
matplotlib \
multi_quantization \
+ numba \
numpy \
+ onnxoptimizer \
+ onnxsim \
onnx \
onnxmltools \
onnxruntime \
+ piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html \
+ pypinyin==0.50.0 \
pytest \
sentencepiece>=0.1.96 \
six \
diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
index bdde97647..77dccb93e 100755
--- a/.github/scripts/docker/generate_build_matrix.py
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -6,8 +6,8 @@ import json
def version_gt(a, b):
- a_major, a_minor = a.split(".")[:2]
- b_major, b_minor = b.split(".")[:2]
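+    # Compare version components as integers; lexicographic string
+    # comparison would wrongly report, e.g., "3.9" > "3.10".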
+ a_major, a_minor = list(map(int, a.split(".")))[:2]
+ b_major, b_minor = list(map(int, b.split(".")))[:2]
if a_major > b_major:
return True
@@ -18,8 +18,8 @@ def version_gt(a, b):
def version_ge(a, b):
- a_major, a_minor = a.split(".")[:2]
- b_major, b_minor = b.split(".")[:2]
+ a_major, a_minor = list(map(int, a.split(".")))[:2]
+ b_major, b_minor = list(map(int, b.split(".")))[:2]
if a_major > b_major:
return True
@@ -43,11 +43,15 @@ def get_torchaudio_version(torch_version):
def get_matrix():
- k2_version = "1.24.4.dev20231220"
- kaldifeat_version = "1.25.3.dev20231221"
- version = "1.2"
- python_version = ["3.8", "3.9", "3.10", "3.11"]
- torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+ k2_version = "1.24.4.dev20240223"
+ kaldifeat_version = "1.25.4.dev20240223"
+ version = "20240401"
+ python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
+ torch_version = []
+ torch_version += ["1.13.0", "1.13.1"]
+ torch_version += ["2.0.0", "2.0.1"]
+ torch_version += ["2.1.0", "2.1.1", "2.1.2"]
+ torch_version += ["2.2.0", "2.2.1", "2.2.2"]
matrix = []
for p in python_version:
@@ -57,10 +61,21 @@ def get_matrix():
if version_gt(p, "3.10") and not version_gt(t, "2.0"):
continue
+ # only torch>=2.2.0 supports python 3.12
+ if version_gt(p, "3.11") and not version_gt(t, "2.1"):
+ continue
+
+ k2_version_2 = k2_version
+ kaldifeat_version_2 = kaldifeat_version
+
+ if t == "2.2.2":
+ k2_version_2 = "1.24.4.dev20240328"
+ kaldifeat_version_2 = "1.25.4.dev20240329"
+
matrix.append(
{
- "k2-version": k2_version,
- "kaldifeat-version": kaldifeat_version,
+ "k2-version": k2_version_2,
+ "kaldifeat-version": kaldifeat_version_2,
"version": version,
"python-version": p,
"torch-version": t,
diff --git a/.github/scripts/generate-piper-phonemize-page.py b/.github/scripts/generate-piper-phonemize-page.py
new file mode 100755
index 000000000..3784d5fa5
--- /dev/null
+++ b/.github/scripts/generate-piper-phonemize-page.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
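+# Generate piper_phonemize.html, a simple HTML index page listing piper_phonemize
+# wheels, so that
+# `pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html`
+# can resolve them.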
+
+
+def main():
+ prefix = (
+ "https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
+ )
+ files = [
+ "piper_phonemize-1.2.0-cp310-cp310-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.2.0-cp311-cp311-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.2.0-cp312-cp312-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.2.0-cp37-cp37m-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.2.0-cp38-cp38-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
+ "piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ ]
+ with open("piper_phonemize.html", "w") as f:
+ for file in files:
+ url = prefix + file
+            f.write(f'<a href="{url}">{file}</a><br/>\n')
+
+
+if __name__ == "__main__":
+ main()
diff --git a/.github/scripts/librispeech/ASR/run.sh b/.github/scripts/librispeech/ASR/run.sh
index 7e9bd8a47..b4450afea 100755
--- a/.github/scripts/librispeech/ASR/run.sh
+++ b/.github/scripts/librispeech/ASR/run.sh
@@ -15,9 +15,9 @@ function prepare_data() {
# cause OOM error for CI later.
mkdir -p download/lm
pushd download/lm
- wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
- wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
- wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
+ wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lm-norm.txt.gz
+ wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-lexicon.txt
+ wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/librispeech-vocab.txt
ls -lh
gunzip librispeech-lm-norm.txt.gz
@@ -64,6 +64,46 @@ function run_diagnostics() {
--print-diagnostics 1
}
+function test_streaming_zipformer_ctc_hlg() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ rm $repo/exp-ctc-rnnt-small/*.onnx
+ ls -lh $repo/exp-ctc-rnnt-small
+
+ # export models to onnx
+ ./zipformer/export-onnx-streaming-ctc.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 3 \
+ --exp-dir $repo/exp-ctc-rnnt-small \
+ --causal 1 \
+ --use-ctc 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ \
+ --num-encoder-layers 2,2,2,2,2,2 \
+ --feedforward-dim 512,768,768,768,768,768 \
+ --encoder-dim 192,256,256,256,256,256 \
+ --encoder-unmasked-dim 192,192,192,192,192,192
+
+ ls -lh $repo/exp-ctc-rnnt-small
+
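+  # Decode a few test waves (including an 8 kHz file) with the exported
+  # streaming CTC model and the HLG decoding graph.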
+ for wav in 0.wav 1.wav 8k.wav; do
+ python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \
+ --nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
+ --words $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.fst \
+ $repo/test_wavs/$wav
+ done
+
+ rm -rf $repo
+}
+
function test_pruned_transducer_stateless_2022_03_12() {
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
@@ -1577,6 +1617,7 @@ function test_transducer_bpe_500_2021_12_23() {
prepare_data
run_diagnostics
+test_streaming_zipformer_ctc_hlg
test_pruned_transducer_stateless_2022_03_12
test_pruned_transducer_stateless2_2022_04_29
test_pruned_transducer_stateless3_2022_04_29
diff --git a/.github/scripts/ljspeech/TTS/run.sh b/.github/scripts/ljspeech/TTS/run.sh
new file mode 100755
index 000000000..707361782
--- /dev/null
+++ b/.github/scripts/ljspeech/TTS/run.sh
@@ -0,0 +1,157 @@
+#!/usr/bin/env bash
+
+set -ex
+
+python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
+python3 -m pip install espnet_tts_frontend
+python3 -m pip install numba
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/ljspeech/TTS
+
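+# Shrink the workload in ./prepare.sh so that the CI test finishes quickly.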
+sed -i.bak s/600/8/g ./prepare.sh
+sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
+sed -i.bak s/500/5/g ./prepare.sh
+git diff
+
+function prepare_data() {
+ # We have created a subset of the data for testing
+ #
+ mkdir download
+ pushd download
+ wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
+ tar xvf LJSpeech-1.1.tar.bz2
+ popd
+
+ ./prepare.sh
+ tree .
+}
+
+function train() {
+ pushd ./vits
+ sed -i.bak s/200/3/g ./train.py
+ git diff .
+ popd
+
+ for t in low medium high; do
+ ./vits/train.py \
+ --exp-dir vits/exp-$t \
+ --model-type $t \
+ --num-epochs 1 \
+ --save-every-n 1 \
+ --num-buckets 2 \
+ --tokens data/tokens.txt \
+ --max-duration 20
+
+ ls -lh vits/exp-$t
+ done
+}
+
+function infer() {
+ for t in low medium high; do
+ ./vits/infer.py \
+ --num-buckets 2 \
+ --model-type $t \
+ --epoch 1 \
+ --exp-dir ./vits/exp-$t \
+ --tokens data/tokens.txt \
+ --max-duration 20
+ done
+}
+
+function export_onnx() {
+ for t in low medium high; do
+ ./vits/export-onnx.py \
+ --model-type $t \
+ --epoch 1 \
+ --exp-dir ./vits/exp-$t \
+ --tokens data/tokens.txt
+
+ ls -lh vits/exp-$t/
+ done
+}
+
+function test_medium() {
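+  # Export a pre-trained medium VITS checkpoint from huggingface to ONNX, run a
+  # test synthesis, and package the model with espeak-ng-data for release.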
+ git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12
+
+ ./vits/export-onnx.py \
+ --model-type medium \
+ --epoch 820 \
+ --exp-dir ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp \
+ --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt
+
+ ls -lh ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp
+
+ ./vits/test_onnx.py \
+ --model-filename ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx \
+ --tokens ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt \
+ --output-filename /icefall/test-medium.wav
+
+ ls -lh /icefall/test-medium.wav
+
+ d=/icefall/vits-icefall-en_US-ljspeech-medium
+ mkdir $d
+ cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/data/tokens.txt $d/
+ cp -v ./icefall-tts-ljspeech-vits-medium-2024-03-12/exp/vits-epoch-820.onnx $d/model.onnx
+
+ rm -rf icefall-tts-ljspeech-vits-medium-2024-03-12
+
+ pushd $d
+ wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
+ tar xf espeak-ng-data.tar.bz2
+ rm espeak-ng-data.tar.bz2
+ cd ..
+ tar cjf vits-icefall-en_US-ljspeech-medium.tar.bz2 vits-icefall-en_US-ljspeech-medium
+ rm -rf vits-icefall-en_US-ljspeech-medium
+ ls -lh *.tar.bz2
+ popd
+}
+
+function test_low() {
+ git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12
+
+ ./vits/export-onnx.py \
+ --model-type low \
+ --epoch 1600 \
+ --exp-dir ./icefall-tts-ljspeech-vits-low-2024-03-12/exp \
+ --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt
+
+ ls -lh ./icefall-tts-ljspeech-vits-low-2024-03-12/exp
+
+ ./vits/test_onnx.py \
+ --model-filename ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx \
+ --tokens ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt \
+ --output-filename /icefall/test-low.wav
+
+ ls -lh /icefall/test-low.wav
+
+ d=/icefall/vits-icefall-en_US-ljspeech-low
+ mkdir $d
+ cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/data/tokens.txt $d/
+ cp -v ./icefall-tts-ljspeech-vits-low-2024-03-12/exp/vits-epoch-1600.onnx $d/model.onnx
+
+ rm -rf icefall-tts-ljspeech-vits-low-2024-03-12
+
+ pushd $d
+ wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
+ tar xf espeak-ng-data.tar.bz2
+ rm espeak-ng-data.tar.bz2
+ cd ..
+ tar cjf vits-icefall-en_US-ljspeech-low.tar.bz2 vits-icefall-en_US-ljspeech-low
+ rm -rf vits-icefall-en_US-ljspeech-low
+ ls -lh *.tar.bz2
+ popd
+}
+
+prepare_data
+train
+infer
+export_onnx
+rm -rf vits/exp-{low,medium,high}
+test_medium
+test_low
diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
index a3a2d3080..981b74b76 100755
--- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
+++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
@@ -30,7 +30,7 @@ log "Test exporting to ONNX format"
./pruned_transducer_stateless2/export-onnx.py \
--exp-dir $repo/exp \
- --lang-dir $repo/data/lang_char \
+ --tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1
@@ -38,14 +38,14 @@ log "Export to torchscript model"
./pruned_transducer_stateless2/export.py \
--exp-dir $repo/exp \
- --lang-dir $repo/data/lang_char \
+ --tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
./pruned_transducer_stateless2/export.py \
--exp-dir $repo/exp \
- --lang-dir $repo/data/lang_char \
+ --tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1 \
--jit-trace 1
diff --git a/.github/workflows/audioset.yml b/.github/workflows/audioset.yml
new file mode 100644
index 000000000..280ef8f8e
--- /dev/null
+++ b/.github/workflows/audioset.yml
@@ -0,0 +1,137 @@
+name: audioset
+
+on:
+ push:
+ branches:
+ - master
+
+ pull_request:
+ branches:
+ - master
+
+ workflow_dispatch:
+
+concurrency:
+ group: audioset-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ generate_build_matrix:
+ if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+ # see https://github.com/pytorch/pytorch/pull/50633
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Generating build matrix
+ id: set-matrix
+ run: |
+ # outputting for debugging purposes
+ python ./.github/scripts/docker/generate_build_matrix.py
+ MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+ echo "::set-output name=matrix::${MATRIX}"
+
+ audioset:
+ needs: generate_build_matrix
+ name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Free space
+ shell: bash
+ run: |
+ ls -lh
+ df -h
+ rm -rf /opt/hostedtoolcache
+ df -h
+ echo "pwd: $PWD"
+ echo "github.workspace ${{ github.workspace }}"
+
+ - name: Run tests
+ uses: addnab/docker-run-action@v3
+ with:
+ image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+ options: |
+ --volume ${{ github.workspace }}/:/icefall
+ shell: bash
+ run: |
+ export PYTHONPATH=/icefall:$PYTHONPATH
+ cd /icefall
+ git config --global --add safe.directory /icefall
+
+ .github/scripts/audioset/AT/run.sh
+
+ - name: Show model files
+ shell: bash
+ run: |
+ sudo chown -R runner ./model-onnx
+ ls -lh ./model-onnx
+ chmod -x ./model-onnx/class_labels_indices.csv
+
+ echo "----------"
+ ls -lh ./model-onnx/*
+
+ - name: Upload model to huggingface
+ if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ uses: nick-fields/retry@v3
+ with:
+ max_attempts: 20
+ timeout_seconds: 200
+ shell: bash
+ command: |
+ git config --global user.email "csukuangfj@gmail.com"
+ git config --global user.name "Fangjun Kuang"
+
+ rm -rf huggingface
+ export GIT_LFS_SKIP_SMUDGE=1
+
+ git clone https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 huggingface
+ cd huggingface
+ git fetch
+ git pull
+ git merge -m "merge remote" --ff origin main
+ cp ../model-onnx/*.onnx ./
+ cp ../model-onnx/*.csv ./
+ cp -a ../model-onnx/test_wavs ./
+ ls -lh
+ git add .
+ git status
+ git commit -m "update models"
+ git status
+
+ git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 main || true
+ rm -rf huggingface
+
+ - name: Prepare for release
+ if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+ shell: bash
+ run: |
+ d=sherpa-onnx-zipformer-audio-tagging-2024-04-09
+ mv ./model-onnx $d
+ tar cjvf ${d}.tar.bz2 $d
+ ls -lh
+
+ - name: Release exported onnx models
+ if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+ uses: svenstaro/upload-release-action@v2
+ with:
+ file_glob: true
+ overwrite: true
+ file: sherpa-onnx-*.tar.bz2
+ repo_name: k2-fsa/sherpa-onnx
+ repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+ tag: audio-tagging-models
+
diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml
index d7fe2c964..c622476f2 100644
--- a/.github/workflows/build-doc.yml
+++ b/.github/workflows/build-doc.yml
@@ -56,11 +56,14 @@ jobs:
- name: Build doc
shell: bash
run: |
+ .github/scripts/generate-piper-phonemize-page.py
cd docs
python3 -m pip install -r ./requirements.txt
make html
touch build/html/.nojekyll
+ cp -v ../piper_phonemize.html ./build/html/
+
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
with:
diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml
index e5d96dcdf..9198cdb7f 100644
--- a/.github/workflows/build-docker-image.yml
+++ b/.github/workflows/build-docker-image.yml
@@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
- image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+ image: ["torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
steps:
# refer to https://github.com/actions/checkout
diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml
new file mode 100644
index 000000000..e202d21b5
--- /dev/null
+++ b/.github/workflows/ljspeech.yml
@@ -0,0 +1,102 @@
+name: ljspeech
+
+on:
+ push:
+ branches:
+ - master
+
+ pull_request:
+ branches:
+ - master
+
+ workflow_dispatch:
+
+concurrency:
+ group: ljspeech-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ generate_build_matrix:
+ if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+ # see https://github.com/pytorch/pytorch/pull/50633
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Generating build matrix
+ id: set-matrix
+ run: |
+ # outputting for debugging purposes
+ python ./.github/scripts/docker/generate_build_matrix.py
+ MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+ echo "::set-output name=matrix::${MATRIX}"
+
+ ljspeech:
+ needs: generate_build_matrix
+ name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Free space
+ shell: bash
+ run: |
+ ls -lh
+ df -h
+ rm -rf /opt/hostedtoolcache
+ df -h
+ echo "pwd: $PWD"
+ echo "github.workspace ${{ github.workspace }}"
+
+ - name: Run tests
+ uses: addnab/docker-run-action@v3
+ with:
+ image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+ options: |
+ --volume ${{ github.workspace }}/:/icefall
+ shell: bash
+ run: |
+ export PYTHONPATH=/icefall:$PYTHONPATH
+ cd /icefall
+ git config --global --add safe.directory /icefall
+
+ .github/scripts/ljspeech/TTS/run.sh
+
+ - name: display files
+ shell: bash
+ run: |
+ ls -lh
+
+ - uses: actions/upload-artifact@v4
+ if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
+ with:
+ name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
+ path: ./*.wav
+
+ - uses: actions/upload-artifact@v4
+ if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
+ with:
+ name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
+ path: ./*.wav
+
+ - name: Release exported onnx models
+ if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+ uses: svenstaro/upload-release-action@v2
+ with:
+ file_glob: true
+ overwrite: true
+ file: vits-icefall-*.tar.bz2
+ repo_name: k2-fsa/sherpa-onnx
+ repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+ tag: tts-models
+
diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml
index d048923b6..a26e704c5 100644
--- a/.github/workflows/run-docker-image.yml
+++ b/.github/workflows/run-docker-image.yml
@@ -14,13 +14,20 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
- image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+ image: ["torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
steps:
# refer to https://github.com/actions/checkout
- uses: actions/checkout@v2
with:
fetch-depth: 0
+ - name: Free space
+ shell: bash
+ run: |
+ df -h
+ rm -rf /opt/hostedtoolcache
+ df -h
+
- name: Run the build process with Docker
uses: addnab/docker-run-action@v3
with:
diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml
index fc1dcbfd4..1c37f13ed 100644
--- a/.github/workflows/style_check.yml
+++ b/.github/workflows/style_check.yml
@@ -49,7 +49,7 @@ jobs:
- name: Install Python dependencies
run: |
- python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
+ python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
# Click issue fixed in https://github.com/psf/black/pull/2966
- name: Run flake8
@@ -67,3 +67,9 @@ jobs:
working-directory: ${{github.workspace}}
run: |
black --check --diff .
+
+ - name: Run isort
+ shell: bash
+ working-directory: ${{github.workspace}}
+ run: |
+ isort --check --diff .
diff --git a/.github/workflows/yesno.yml b/.github/workflows/yesno.yml
index 182300dfa..de822b33f 100644
--- a/.github/workflows/yesno.yml
+++ b/.github/workflows/yesno.yml
@@ -59,4 +59,7 @@ jobs:
cd /icefall
git config --global --add safe.directory /icefall
+ python3 -m torch.utils.collect_env
+ python3 -m k2.version
+
.github/scripts/yesno/ASR/run.sh
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1bb38f6ba..70068f9cf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,7 +26,7 @@ repos:
# E121,E123,E126,E226,E24,E704,W503,W504
- repo: https://github.com/pycqa/isort
- rev: 5.11.5
+ rev: 5.12.0
hooks:
- id: isort
args: ["--profile=black"]
diff --git a/README.md b/README.md
index 15e9e17e6..770066166 100644
--- a/README.md
+++ b/README.md
@@ -2,46 +2,86 @@
-## Introduction
+# Introduction
-icefall contains ASR recipes for various datasets
-using .
+The icefall project contains speech-related recipes for various datasets
+using [k2-fsa](https://github.com/k2-fsa/k2) and [lhotse](https://github.com/lhotse-speech/lhotse).
-You can use to deploy models
-trained with icefall.
+You can use [sherpa](https://github.com/k2-fsa/sherpa), [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn), or [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx) to deploy models trained
+in icefall. These frameworks also support models that are not included in icefall; please refer to their respective documentation for more details.
You can try pre-trained models from within your browser without the need
-to download or install anything by visiting
-See for more details.
+to download or install anything by visiting this [huggingface space](https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition).
+Please refer to the [documentation](https://k2-fsa.github.io/icefall/huggingface/spaces.html) for more details.
-## Installation
+# Installation
-Please refer to
+Please refer to the [documentation](https://icefall.readthedocs.io/en/latest/installation/index.html)
for installation.
-## Recipes
+# Recipes
-Please refer to
-for more information.
+Please refer to the [documentation](https://icefall.readthedocs.io/en/latest/recipes/index.html)
+for more details.
-We provide the following recipes:
+## ASR: Automatic Speech Recognition
+### Supported Datasets
- [yesno][yesno]
- - [LibriSpeech][librispeech]
- - [GigaSpeech][gigaspeech]
- - [AMI][ami]
+
+ - [Aidatatang_200zh][aidatatang_200zh]
- [Aishell][aishell]
- [Aishell2][aishell2]
- [Aishell4][aishell4]
+ - [Alimeeting][alimeeting]
+ - [AMI][ami]
+ - [CommonVoice][commonvoice]
+ - [Corpus of Spontaneous Japanese][csj]
+ - [GigaSpeech][gigaspeech]
+ - [LibriCSS][libricss]
+ - [LibriSpeech][librispeech]
+ - [Libriheavy][libriheavy]
+ - [Multi-Dialect Broadcast News Arabic Speech Recognition][mgb2]
+ - [PeopleSpeech][peoplespeech]
+ - [SPGISpeech][spgispeech]
+ - [Switchboard][swbd]
- [TIMIT][timit]
- [TED-LIUM3][tedlium3]
- - [Aidatatang_200zh][aidatatang_200zh]
- - [WenetSpeech][wenetspeech]
- - [Alimeeting][alimeeting]
- - [Switchboard][swbd]
- [TAL_CSASR][tal_csasr]
+ - [Voxpopuli][voxpopuli]
+ - [XBMU-AMDO31][xbmu-amdo31]
+ - [WenetSpeech][wenetspeech]
+
+More datasets will be added in the future.
-### yesno
+### Supported Models
+
+The [LibriSpeech][librispeech] recipe supports the most comprehensive set of models; you are welcome to try them out.
+
+#### CTC
+ - TDNN LSTM CTC
+ - Conformer CTC
+ - Zipformer CTC
+
+#### MMI
+ - Conformer MMI
+ - Zipformer MMI
+
+#### Transducer
+ - Conformer-based Encoder
+ - LSTM-based Encoder
+ - Zipformer-based Encoder
+ - LSTM-based Predictor
+ - [Stateless Predictor](https://research.google/pubs/rnn-transducer-with-stateless-prediction-network/)
+
+#### Whisper
+ - [OpenAI Whisper](https://arxiv.org/abs/2212.04356) (We support fine-tuning on Aishell-1.)
+
+If you would like to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details.
+
+We would like to highlight the performance of some of the recipes here.
+
+### [yesno][yesno]
This is the simplest ASR recipe in `icefall` and can be run on CPU.
Training takes less than 30 seconds and gives you the following WER:
@@ -52,350 +92,264 @@ Training takes less than 30 seconds and gives you the following WER:
We provide a Colab notebook for this recipe: [](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing)
-### LibriSpeech
+### [LibriSpeech][librispeech]
-Please see
+Please see [RESULTS.md](https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md)
for the **latest** results.
-We provide 5 models for this recipe:
-
-- [conformer CTC model][LibriSpeech_conformer_ctc]
-- [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc]
-- [Transducer: Conformer encoder + LSTM decoder][LibriSpeech_transducer]
-- [Transducer: Conformer encoder + Embedding decoder][LibriSpeech_transducer_stateless]
-- [Transducer: Zipformer encoder + Embedding decoder][LibriSpeech_zipformer]
-
-#### Conformer CTC Model
-
-The best WER we currently have is:
+#### [Conformer CTC](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conformer_ctc)
| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.42 | 5.73 |
-We provide a Colab notebook to run a pre-trained conformer CTC model: [](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
-#### TDNN LSTM CTC Model
-
-The WER for this model is:
+#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/tdnn_lstm_ctc)
| | test-clean | test-other |
|-----|------------|------------|
| WER | 6.59 | 17.69 |
-We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing)
-#### Transducer: Conformer encoder + LSTM decoder
+#### [Transducer (Conformer Encoder + LSTM Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/transducer)
-Using Conformer as encoder and LSTM as decoder.
+| | test-clean | test-other |
+|---------------|------------|------------|
+| greedy_search | 3.07 | 7.51 |
-The best WER with greedy search is:
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
-| | test-clean | test-other |
-|-----|------------|------------|
-| WER | 3.07 | 7.51 |
+#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/transducer)
-We provide a Colab notebook to run a pre-trained RNN-T conformer model: [](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
-
-#### Transducer: Conformer encoder + Embedding decoder
-
-Using Conformer as encoder. The decoder consists of 1 embedding layer
-and 1 convolutional layer.
-
-The best WER using modified beam search with beam size 4 is:
-
-| | test-clean | test-other |
-|-----|------------|------------|
-| WER | 2.56 | 6.27 |
-
-Note: No auxiliary losses are used in the training and no LMs are used
-in the decoding.
-
-We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
+| | test-clean | test-other |
+|---------------------------------------|------------|------------|
+| modified_beam_search (`beam_size=4`) | 2.56 | 6.27 |
-#### k2 pruned RNN-T
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
+
+
+#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/zipformer)
+
+WER (modified_beam_search with `beam_size=4` unless otherwise stated)
+
+1. LibriSpeech-960hr
| Encoder | Params | test-clean | test-other | epochs | devices |
|-----------------|--------|------------|------------|---------|------------|
-| zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 |
-| zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 |
-| zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 |
-| zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 |
+| Zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 |
+| Zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 |
+| Zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 |
+| Zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 |
-Note: No auxiliary losses are used in the training and no LMs are used
-in the decoding.
+2. LibriSpeech-960hr + GigaSpeech
-#### k2 pruned RNN-T + GigaSpeech
-
-| | test-clean | test-other |
-|-----|------------|------------|
-| WER | 1.78 | 4.08 |
-
-Note: No auxiliary losses are used in the training and no LMs are used
-in the decoding.
-
-#### k2 pruned RNN-T + GigaSpeech + CommonVoice
-
-| | test-clean | test-other |
-|-----|------------|------------|
-| WER | 1.90 | 3.98 |
-
-Note: No auxiliary losses are used in the training and no LMs are used
-in the decoding.
+| Encoder | Params | test-clean | test-other |
+|-----------------|--------|------------|------------|
+| Zipformer | 65.5M | 1.78 | 4.08 |
-### GigaSpeech
+3. LibriSpeech-960hr + GigaSpeech + CommonVoice
-We provide three models for this recipe:
+| Encoder | Params | test-clean | test-other |
+|-----------------|--------|------------|------------|
+| Zipformer | 65.5M | 1.90 | 3.98 |
-- [Conformer CTC model][GigaSpeech_conformer_ctc]
-- [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
-- [Transducer: Zipformer encoder + Embedding decoder][GigaSpeech_zipformer]
-#### Conformer CTC
+### [GigaSpeech][gigaspeech]
+
+#### [Conformer CTC](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/conformer_ctc)
| | Dev | Test |
|-----|-------|-------|
| WER | 10.47 | 10.58 |
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/pruned_transducer_stateless2)
+
+Conformer Encoder + Stateless Predictor + k2 Pruned RNN-T Loss
| | Dev | Test |
|----------------------|-------|-------|
-| greedy search | 10.51 | 10.73 |
-| fast beam search | 10.50 | 10.69 |
-| modified beam search | 10.40 | 10.51 |
+| greedy_search | 10.51 | 10.73 |
+| fast_beam_search | 10.50 | 10.69 |
+| modified_beam_search | 10.40 | 10.51 |
-#### Transducer: Zipformer encoder + Embedding decoder
+#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/zipformer)
| | Dev | Test |
|----------------------|-------|-------|
-| greedy search | 10.31 | 10.50 |
-| fast beam search | 10.26 | 10.48 |
-| modified beam search | 10.25 | 10.38 |
+| greedy_search | 10.31 | 10.50 |
+| fast_beam_search | 10.26 | 10.48 |
+| modified_beam_search | 10.25 | 10.38 |
-### Aishell
+### [Aishell][aishell]
-We provide three models for this recipe: [conformer CTC model][Aishell_conformer_ctc],
-[TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc], and [Transducer Stateless Model][Aishell_pruned_transducer_stateless7],
-
-#### Conformer CTC Model
-
-The best CER we currently have is:
-
-| | test |
-|-----|------|
-| CER | 4.26 |
-
-#### TDNN LSTM CTC Model
-
-The CER for this model is:
+#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/tdnn_lstm_ctc)
| | test |
|-----|-------|
| CER | 10.16 |
-We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing)
-#### Transducer Stateless Model
-
-The best CER we currently have is:
+#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/transducer_stateless)
| | test |
|-----|------|
| CER | 4.38 |
-We provide a Colab notebook to run a pre-trained TransducerStateless model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
+
+#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/zipformer)
+
+WER (modified_beam_search `beam_size=4`)
+
+| Encoder | Params | dev | test | epochs |
+|-----------------|--------|-----|------|---------|
+| Zipformer | 73.4M | 4.13| 4.40 | 55 |
+| Zipformer-small | 30.2M | 4.40| 4.67 | 55 |
+| Zipformer-large | 157.3M | 4.03| 4.28 | 56 |
-### Aishell2
+### [Aishell4][aishell4]
-We provide one model for this recipe: [Transducer Stateless Model][Aishell2_pruned_transducer_stateless5].
-
-#### Transducer Stateless Model
-
-The best WER we currently have is:
-
-| | dev-ios | test-ios |
-|-----|------------|------------|
-| WER | 5.32 | 5.56 |
-
-
-### Aishell4
-
-We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
-
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
-
-The best CER we currently have is:
+#### [Transducer (pruned_transducer_stateless5)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell4/ASR/pruned_transducer_stateless5)
+Trained with all subsets:
| | test |
|-----|------------|
| CER | 29.08 |
-
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
-### TIMIT
+### [TIMIT][timit]
-We provide two models for this recipe: [TDNN LSTM CTC model][TIMIT_tdnn_lstm_ctc]
-and [TDNN LiGRU CTC model][TIMIT_tdnn_ligru_ctc].
+#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/timit/ASR/tdnn_lstm_ctc)
-#### TDNN LSTM CTC Model
-
-The best PER we currently have is:
-
-||TEST|
-|--|--|
+| |TEST|
+|---|----|
|PER| 19.71% |
-We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1Hs9DA4V96uapw_30uNp32OMJgkuR5VVd?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1Hs9DA4V96uapw_30uNp32OMJgkuR5VVd?usp=sharing)
-#### TDNN LiGRU CTC Model
+#### [TDNN LiGRU CTC](https://github.com/k2-fsa/icefall/tree/master/egs/timit/ASR/tdnn_ligru_ctc)
-The PER for this model is:
-
-||TEST|
-|--|--|
+| |TEST|
+|---|----|
|PER| 17.66% |
-We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
-### TED-LIUM3
+### [TED-LIUM3][tedlium3]
-We provide two models for this recipe: [Transducer Stateless: Conformer encoder + Embedding decoder][TED-LIUM3_transducer_stateless] and [Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TED-LIUM3_pruned_transducer_stateless].
+#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/transducer_stateless)
-#### Transducer Stateless: Conformer encoder + Embedding decoder
-
-The best WER using modified beam search with beam size 4 is:
-
-| | dev | test |
-|-----|-------|--------|
-| WER | 6.91 | 6.33 |
-
-Note: No auxiliary losses are used in the training and no LMs are used in the decoding.
-
-We provide a Colab notebook to run a pre-trained Transducer Stateless model: [](https://colab.research.google.com/drive/1MmY5bBxwvKLNT4A2DJnwiqRXhdchUqPN?usp=sharing)
-
-#### Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
-
-The best WER using modified beam search with beam size 4 is:
-
-| | dev | test |
-|-----|-------|--------|
-| WER | 6.77 | 6.14 |
-
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing)
+| | dev | test |
+|--------------------------------------|-------|--------|
+| modified_beam_search (`beam_size=4`) | 6.91 | 6.33 |
-### Aidatatang_200zh
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1MmY5bBxwvKLNT4A2DJnwiqRXhdchUqPN?usp=sharing)
-We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aidatatang_200zh_pruned_transducer_stateless2].
+#### [Transducer (pruned_transducer_stateless)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/pruned_transducer_stateless)
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+| | dev | test |
+|--------------------------------------|-------|--------|
+| modified_beam_search (`beam_size=4`) | 6.77 | 6.14 |
+
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing)
+
+
+### [Aidatatang_200zh][aidatatang_200zh]
+
+#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2)
| | Dev | Test |
|----------------------|-------|-------|
-| greedy search | 5.53 | 6.59 |
-| fast beam search | 5.30 | 6.34 |
-| modified beam search | 5.27 | 6.33 |
+| greedy_search | 5.53 | 6.59 |
+| fast_beam_search | 5.30 | 6.34 |
+| modified_beam_search | 5.27 | 6.33 |
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing)
-### WenetSpeech
+### [WenetSpeech][wenetspeech]
-We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5].
-
-#### Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset, offline ASR)
+#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless2)
| | Dev | Test-Net | Test-Meeting |
|----------------------|-------|----------|--------------|
-| greedy search | 7.80 | 8.75 | 13.49 |
-| modified beam search| 7.76 | 8.71 | 13.41 |
-| fast beam search | 7.94 | 8.74 | 13.80 |
+| greedy_search | 7.80 | 8.75 | 13.49 |
+| fast_beam_search | 7.94 | 8.74 | 13.80 |
+| modified_beam_search | 7.76 | 8.71 | 13.41 |
+
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
+
+#### [Transducer **Streaming** (pruned_transducer_stateless5) ](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless5)
-#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
-**Streaming**:
| | Dev | Test-Net | Test-Meeting |
|----------------------|-------|----------|--------------|
| greedy_search | 8.78 | 10.12 | 16.16 |
-| modified_beam_search | 8.53| 9.95 | 15.81 |
| fast_beam_search| 9.01 | 10.47 | 16.28 |
+| modified_beam_search | 8.53| 9.95 | 15.81 |
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless2 model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
-### Alimeeting
+### [Alimeeting][alimeeting]
-We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Alimeeting_pruned_transducer_stateless2].
-
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with far subset)
+#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/alimeeting/ASR/pruned_transducer_stateless2)
| | Eval | Test-Net |
|----------------------|--------|----------|
-| greedy search | 31.77 | 34.66 |
-| fast beam search | 31.39 | 33.02 |
-| modified beam search | 30.38 | 34.25 |
+| greedy_search | 31.77 | 34.66 |
+| fast_beam_search | 31.39 | 33.02 |
+| modified_beam_search | 30.38 | 34.25 |
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
-### TAL_CSASR
+### [TAL_CSASR][tal_csasr]
-We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5].
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+#### [Transducer (pruned_transducer_stateless5)](https://github.com/k2-fsa/icefall/tree/master/egs/tal_csasr/ASR/pruned_transducer_stateless5)
The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|--|--|--|--|--|--|--|
|greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
-|modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 |
|fast_beam_search| 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77|
+|modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 |
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
-## Deployment with C++
+## TTS: Text-to-Speech
-Once you have trained a model in icefall, you may want to deploy it with C++,
-without Python dependencies.
+### Supported Datasets
-Please refer to the documentation
-
+ - [LJSpeech][ljspeech]
+ - [VCTK][vctk]
+
+### Supported Models
+
+ - [VITS](https://arxiv.org/abs/2106.06103)
+
+# Deployment with C++
+
+Once you have trained a model in icefall, you may want to deploy it with C++ without Python dependencies.
+
+Please refer to the [document](https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/librispeech/conformer_ctc.html#deployment-with-c)
for how to do this.
We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++.
Please see: [](https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing)
-[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
-[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
-[LibriSpeech_transducer]: egs/librispeech/ASR/transducer
-[LibriSpeech_transducer_stateless]: egs/librispeech/ASR/transducer_stateless
-[LibriSpeech_zipformer]: egs/librispeech/ASR/zipformer
-[Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc
-[Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc
-[Aishell_pruned_transducer_stateless7]: egs/aishell/ASR/pruned_transducer_stateless7_bbpe
-[Aishell2_pruned_transducer_stateless5]: egs/aishell2/ASR/pruned_transducer_stateless5
-[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
-[TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc
-[TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc
-[TED-LIUM3_transducer_stateless]: egs/tedlium3/ASR/transducer_stateless
-[TED-LIUM3_pruned_transducer_stateless]: egs/tedlium3/ASR/pruned_transducer_stateless
-[GigaSpeech_conformer_ctc]: egs/gigaspeech/ASR/conformer_ctc
-[GigaSpeech_pruned_transducer_stateless2]: egs/gigaspeech/ASR/pruned_transducer_stateless2
-[GigaSpeech_zipformer]: egs/gigaspeech/ASR/zipformer
-[Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2
-[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2
-[WenetSpeech_pruned_transducer_stateless5]: egs/wenetspeech/ASR/pruned_transducer_stateless5
-[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2
-[TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5
[yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR
[aishell]: egs/aishell/ASR
@@ -411,3 +365,15 @@ Please see: [ is first proposed `here `_
to address the language information mismatch between the training
corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
-are acoustically similar, DR derives the following formular for decoding with Bayes' theorem:
+are acoustically similar, DR derives the following formula for decoding with Bayes' theorem:
.. math::
@@ -41,7 +41,7 @@ are acoustically similar, DR derives the following formular for decoding with Ba
where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively.
-Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to
+Here, the source domain LM is trained on the training corpus. The only difference in the above formula compared to
shallow fusion is the subtraction of the source domain LM.
Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is
@@ -58,7 +58,7 @@ during decoding for transducer model:
In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR,
the only difference lies in the choice of source domain LM. According to the original `paper `_,
-LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
+LODR achieves similar performance compared to DR in both intra-domain and cross-domain settings.
As a bi-gram is much faster to evaluate, LODR is usually much faster.
Now, we will show you how to use LODR in ``icefall``.
diff --git a/docs/source/decoding-with-langugage-models/shallow-fusion.rst b/docs/source/decoding-with-langugage-models/shallow-fusion.rst
index 684fefeb4..8b2586730 100644
--- a/docs/source/decoding-with-langugage-models/shallow-fusion.rst
+++ b/docs/source/decoding-with-langugage-models/shallow-fusion.rst
@@ -9,9 +9,9 @@ to improve the word-error-rate of a transducer model.
.. note::
- This tutorial is based on the recipe
+ This tutorial is based on the recipe
`pruned_transducer_stateless7_streaming `_,
- which is a streaming transducer model trained on `LibriSpeech`_.
+ which is a streaming transducer model trained on `LibriSpeech`_.
However, you can easily apply shallow fusion to other recipes.
If you encounter any problems, please open an issue here `icefall `_.
@@ -69,11 +69,11 @@ Training a language model usually takes a long time, we can download a pre-train
.. code-block:: bash
$ # download the external LM
- $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+ $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
$ # create a symbolic link so that the checkpoint can be loaded
$ pushd icefall-librispeech-rnn-lm/exp
$ git lfs pull --include "pretrained.pt"
- $ ln -s pretrained.pt epoch-99.pt
+ $ ln -s pretrained.pt epoch-99.pt
$ popd
.. note::
@@ -85,7 +85,7 @@ Training a language model usually takes a long time, we can download a pre-train
To use shallow fusion for decoding, we can execute the following command:
.. code-block:: bash
-
+
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.29
@@ -133,16 +133,16 @@ The decoding result obtained with the above command are shown below.
$ For test-other, WER of different settings are:
$ beam_size_4 7.08 best for test-other
-The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
+The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
A few parameters can be tuned to further boost the performance of shallow fusion:
-- ``--lm-scale``
+- ``--lm-scale``
- Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
- the LM score may dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
+ Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
+ the LM score might be dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
+
+- ``--beam-size``
-- ``--beam-size``
-
The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy.
Here, we also show how `--beam-size` effect the WER and decoding time:
@@ -176,4 +176,4 @@ As we see, a larger beam size during shallow fusion improves the WER, but is als
-
+
diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst
index cbd300d9b..2f4bdb3f6 100644
--- a/docs/source/docker/intro.rst
+++ b/docs/source/docker/intro.rst
@@ -34,6 +34,12 @@ which will give you something like below:
.. code-block:: bash
+ "torch2.2.2-cuda12.1"
+ "torch2.2.2-cuda11.8"
+ "torch2.2.1-cuda12.1"
+ "torch2.2.1-cuda11.8"
+ "torch2.2.0-cuda12.1"
+ "torch2.2.0-cuda11.8"
"torch2.1.0-cuda12.1"
"torch2.1.0-cuda11.8"
"torch2.0.0-cuda11.7"
diff --git a/docs/source/for-dummies/environment-setup.rst b/docs/source/for-dummies/environment-setup.rst
index a68e9d3ed..e257b915c 100644
--- a/docs/source/for-dummies/environment-setup.rst
+++ b/docs/source/for-dummies/environment-setup.rst
@@ -74,6 +74,10 @@ to install dependencies of `icefall`_:
pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html
+  # For users from China: if huggingface is not accessible, please use
+ # pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu-cn.html
+
# Install the latest version of lhotse
pip install git+https://github.com/lhotse-speech/lhotse
diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst
index 5a034ef5b..87318f30e 100644
--- a/docs/source/installation/index.rst
+++ b/docs/source/installation/index.rst
@@ -206,6 +206,9 @@ We will install `k2`_ from pre-compiled wheels by following
.. code-block:: bash
(test-icefall) kuangfangjun:~$ pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda.html
+  # For users from China: if huggingface is not accessible, please use
+ # pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda-cn.html
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Looking in links: https://k2-fsa.github.io/k2/cuda.html
diff --git a/docs/source/recipes/Finetune/adapter/finetune_adapter.rst b/docs/source/recipes/Finetune/adapter/finetune_adapter.rst
new file mode 100644
index 000000000..a94b008f6
--- /dev/null
+++ b/docs/source/recipes/Finetune/adapter/finetune_adapter.rst
@@ -0,0 +1,225 @@
+Finetune from a pre-trained Zipformer model with adapters
+=========================================================
+
+This tutorial shows you how to fine-tune a pre-trained **Zipformer**
+transducer model on a new dataset with adapters.
+Adapters are compact and efficient modules that can be integrated into a pre-trained model
+to improve the model's performance on a new domain. Adapters are injected
+between different modules in the well-trained neural network. During training, only the parameters
+in the adapters are updated. This approach achieves competitive performance
+while requiring much less GPU memory than full fine-tuning. For more details about adapters,
+please refer to the original `paper `_.
+
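+The core idea can be illustrated with a minimal sketch of a bottleneck adapter
+with a residual connection. Note that the class name and layer choices below are
+only illustrative and are not necessarily identical to the adapter module used in
+``zipformer_adapter``:
+
+.. code-block:: python
+
+   import torch
+   import torch.nn as nn
+
+   class Adapter(nn.Module):
+       """A bottleneck adapter with a residual connection (illustrative)."""
+
+       def __init__(self, embed_dim: int, adapter_dim: int = 8):
+           super().__init__()
+           self.down = nn.Linear(embed_dim, adapter_dim)  # project to the bottleneck
+           self.activation = nn.ReLU()
+           self.up = nn.Linear(adapter_dim, embed_dim)  # project back up
+
+       def forward(self, x: torch.Tensor) -> torch.Tensor:
+           # The residual connection means the adapter only learns a small
+           # correction on top of the frozen pre-trained representation.
+           return x + self.up(self.activation(self.down(x)))
+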
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have set up
+ the environment for ``icefall``.
+
+.. HINT::
+
+ We recommend using one or more GPUs to run this recipe.
+
+For illustration purposes, we fine-tune the Zipformer transducer model
+pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You can use your
+own data for fine-tuning as long as you create a manifest for your new dataset.
+
+Data preparation
+----------------
+
+Please follow the instructions in the `GigaSpeech recipe `_
+to prepare the fine-tuning data used in this tutorial. Only the small subset of GigaSpeech is required.
+
+
+Model preparation
+-----------------
+
+We are using the Zipformer model trained on full LibriSpeech (960 hours) as the initialization. The
+checkpoint of the model can be downloaded via the following command:
+
+.. code-block:: bash
+
+ $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+ $ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
+ $ git lfs pull --include "pretrained.pt"
+ $ ln -s pretrained.pt epoch-99.pt
+ $ cd ../data/lang_bpe_500
+ $ git lfs pull --include bpe.model
+ $ cd ../../..
+
+Before fine-tuning, let's test the model's WER on the new domain. The following command performs
+decoding on the GigaSpeech test sets:
+
+.. code-block:: bash
+
+ ./zipformer/decode_gigaspeech.py \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
+ --use-averaged-model 0 \
+ --max-duration 1000 \
+ --decoding-method greedy_search
+
+You should see the following numbers:
+
+.. code-block::
+
+ For dev, WER of different settings are:
+ greedy_search 20.06 best for dev
+
+ For test, WER of different settings are:
+ greedy_search 19.27 best for test
+
+
+Fine-tune with adapter
+----------------------
+
+We insert 4 adapters with residual connections in each ``Zipformer2EncoderLayer``.
+The original model parameters remain untouched during training and only the parameters of
+the adapters are updated. The following command starts a fine-tuning experiment with adapters:
+
+.. code-block:: bash
+
+ $ do_finetune=1
+ $ use_adapters=1
+ $ adapter_dim=8
+
+ $ ./zipformer_adapter/train.py \
+ --world-size 2 \
+ --num-epochs 20 \
+ --start-epoch 1 \
+ --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
+ --use-fp16 1 \
+ --base-lr 0.045 \
+ --use-adapters $use_adapters --adapter-dim $adapter_dim \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --do-finetune $do_finetune \
+ --master-port 13022 \
+ --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
+ --max-duration 1000
+
+The following arguments are related to fine-tuning:
+
+- ``--do-finetune``
+ If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
+ **Note that if you want to resume your fine-tuning experiment from a certain epoch, you
+ need to set this to False.**
+
+- ``--use-adapters``
+  Whether to use adapters during fine-tuning.
+
+- ``--adapter-dim``
+ The bottleneck dimension of the adapter module. Typically a small number.
+
+You should notice that in the training log, the total number of trainable parameters is shown:
+
+.. code-block::
+
+ 2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model)
+
+The trainable parameters make up only 1.15% of the total model parameters, so training is much faster
+and requires less memory than full fine-tuning. A rough sketch of how this fraction can be computed is shown below.
+
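+The sketch below shows one way to freeze everything except the adapters and to
+compute the trainable fraction. It is only an illustration: the name filter
+``"adapter"`` is an assumption about the parameter naming scheme, so adjust it
+to the actual module names in ``zipformer_adapter``.
+
+.. code-block:: python
+
+   import torch.nn as nn
+
+   def freeze_non_adapter_params(model: nn.Module) -> None:
+       # Only parameters whose names contain "adapter" stay trainable
+       # (assumed naming scheme; adjust as needed).
+       for name, p in model.named_parameters():
+           p.requires_grad = "adapter" in name
+
+   def trainable_stats(model: nn.Module) -> str:
+       num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+       num_total = sum(p.numel() for p in model.parameters())
+       percent = 100.0 * num_trainable / num_total
+       return f"A total of {num_trainable} trainable parameters ({percent:.3f}% of the whole model)"
+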
+
+Decoding
+--------
+
+After training, let's test the WERs. To test the WERs on the GigaSpeech test sets,
+you can execute the following command:
+
+.. code-block:: bash
+
+ $ epoch=20
+ $ avg=10
+ $ use_adapters=1
+ $ adapter_dim=8
+
+ $ ./zipformer_adapter/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --use-averaged-model 1 \
+ --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
+ --max-duration 600 \
+ --use-adapters $use_adapters \
+ --adapter-dim $adapter_dim \
+ --decoding-method greedy_search
+
+You should see the following numbers:
+
+.. code-block::
+
+ For dev, WER of different settings are:
+ greedy_search 15.44 best for dev
+
+ For test, WER of different settings are:
+ greedy_search 15.42 best for test
+
+
+The WER on the test set improves from 19.27 to 15.42, demonstrating the effectiveness of adapters.
+
+The same model can also be used to decode the LibriSpeech test sets. You can deactivate the adapters
+to recover the performance of the original model:
+
+.. code-block:: bash
+
+ $ epoch=20
+ $ avg=1
+ $ use_adapters=0
+ $ adapter_dim=8
+
+ $ ./zipformer_adapter/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --use-averaged-model 1 \
+ --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
+ --max-duration 600 \
+ --use-adapters $use_adapters \
+ --adapter-dim $adapter_dim \
+ --decoding-method greedy_search
+
+
+.. code-block::
+
+ For dev, WER of different settings are:
+ greedy_search 2.23 best for test-clean
+
+ For test, WER of different settings are:
+ greedy_search 4.96 best for test-other
+
+The numbers are the same as reported in `icefall `_. Adapter-based
+fine-tuning is therefore very flexible, as the same model can be used for decoding on both the original and the target domain.
+
+
+Export the model
+----------------
+
+After training, the model can easily be exported to ``onnx`` format using the following command:
+
+.. code-block:: bash
+
+ $ use_adapters=1
+ $ adapter_dim=8
+
+ $ ./zipformer_adapter/export-onnx.py \
+ --tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model 1 \
+ --epoch 20 \
+ --avg 10 \
+ --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
+ --use-adapters $use_adapters \
+ --adapter-dim $adapter_dim \
+ --num-encoder-layers "2,2,3,4,3,2" \
+ --downsampling-factor "1,2,4,8,4,2" \
+ --feedforward-dim "512,768,1024,1536,1024,768" \
+ --num-heads "4,4,4,8,4,4" \
+ --encoder-dim "192,256,384,512,384,256" \
+ --query-head-dim 32 \
+ --value-head-dim 12 \
+ --pos-head-dim 4 \
+ --pos-dim 48 \
+ --encoder-unmasked-dim "192,192,256,256,256,192" \
+ --cnn-module-kernel "31,31,15,15,15,31" \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --causal False \
+ --chunk-size "16,32,64,-1" \
+ --left-context-frames "64,128,256,-1"
\ No newline at end of file
diff --git a/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst b/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst
new file mode 100644
index 000000000..7ca4eb811
--- /dev/null
+++ b/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst
@@ -0,0 +1,140 @@
+Finetune from a supervised pre-trained Zipformer model
+======================================================
+
+This tutorial shows you how to fine-tune a supervised pre-trained **Zipformer**
+transducer model on a new dataset.
+
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have set up
+ the environment for ``icefall``.
+
+.. HINT::
+
+ We recommend using one or more GPUs to run this recipe.
+
+
+For illustration purposes, we fine-tune the Zipformer transducer model
+pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You can use your
+own data for fine-tuning as long as you create a manifest for your new dataset.
+
+Data preparation
+----------------
+
+Please follow the instructions in the `GigaSpeech recipe `_
+to prepare the fine-tuning data used in this tutorial. Only the small subset of GigaSpeech is required.
+
+
+Model preparation
+-----------------
+
+We are using the Zipformer model trained on full LibriSpeech (960 hours) as the initialization. The
+checkpoint of the model can be downloaded via the following command:
+
+.. code-block:: bash
+
+ $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+ $ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
+ $ git lfs pull --include "pretrained.pt"
+ $ ln -s pretrained.pt epoch-99.pt
+ $ cd ../data/lang_bpe_500
+ $ git lfs pull --include bpe.model
+ $ cd ../../..
+
+Before fine-tuning, let's test the model's WER on the new domain. The following command performs
+decoding on the GigaSpeech test sets:
+
+.. code-block:: bash
+
+ ./zipformer/decode_gigaspeech.py \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
+ --use-averaged-model 0 \
+ --max-duration 1000 \
+ --decoding-method greedy_search
+
+You should see the following numbers:
+
+.. code-block::
+
+ For dev, WER of different settings are:
+ greedy_search 20.06 best for dev
+
+ For test, WER of different settings are:
+ greedy_search 19.27 best for test
+
+
+Fine-tune
+---------
+
+Since LibriSpeech and GigaSpeech are both English datasets, we can initialize the whole
+Zipformer model with the checkpoint downloaded in the previous step (otherwise we would need to
+consider initializing the stateless decoder and joiner from scratch due to the mismatch of the
+output vocabulary). The following command starts a fine-tuning experiment:
+
+.. code-block:: bash
+
+ $ use_mux=0
+ $ do_finetune=1
+
+ $ ./zipformer/finetune.py \
+ --world-size 2 \
+ --num-epochs 20 \
+ --start-epoch 1 \
+ --exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
+ --use-fp16 1 \
+ --base-lr 0.0045 \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --do-finetune $do_finetune \
+ --use-mux $use_mux \
+ --master-port 13024 \
+ --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
+ --max-duration 1000
+
+The following arguments are related to fine-tuning:
+
+- ``--base-lr``
+  The learning rate used for fine-tuning. We suggest setting a **small** learning rate for fine-tuning,
+  otherwise the model may forget the initialization very quickly. A reasonable value is around
+  1/10 of the original learning rate, i.e., 0.0045.
+
+- ``--do-finetune``
+ If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
+ **Note that if you want to resume your fine-tuning experiment from a certain epoch, you
+ need to set this to False.**
+
+- ``--finetune-ckpt``
+ The path to the pre-trained checkpoint (used for initialization).
+
+- ``--use-mux``
+  If True, mix the fine-tuning data with the original training data using `CutSet.mux `_.
+  This helps maintain the model's performance on the original domain if the original training
+  data is available. **If you don't have the original training data, please set it to False.**
+  See the sketch below for an illustration of this mixing.
+
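+The sketch below shows roughly what this mixing looks like with ``lhotse``. The
+manifest file names and the weights are illustrative; use the manifests produced
+by your own data preparation:
+
+.. code-block:: python
+
+   from lhotse import CutSet
+
+   # Illustrative manifest paths.
+   giga_cuts = CutSet.from_file("data/fbank/gigaspeech_cuts_S.jsonl.gz")
+   libri_cuts = CutSet.from_file("data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz")
+
+   # Interleave the two sources; cuts are drawn from each CutSet according to
+   # the given relative weights.
+   mixed_cuts = CutSet.mux(giga_cuts, libri_cuts, weights=[0.5, 0.5])
+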
+After fine-tuning, let's test the WERs. You can do this via the following command:
+
+.. code-block:: bash
+
+ $ use_mux=0
+ $ do_finetune=1
+ $ ./zipformer/decode_gigaspeech.py \
+ --epoch 20 \
+ --avg 10 \
+ --exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
+ --use-averaged-model 1 \
+ --max-duration 1000 \
+ --decoding-method greedy_search
+
+You should see numbers similar to the ones below:
+
+.. code-block:: text
+
+ For dev, WER of different settings are:
+ greedy_search 13.47 best for dev
+
+ For test, WER of different settings are:
+ greedy_search 13.66 best for test
+
+Compared to the original checkpoint, the fine-tuned model achieves much lower WERs
+on the GigaSpeech test sets.
diff --git a/docs/source/recipes/Finetune/index.rst b/docs/source/recipes/Finetune/index.rst
new file mode 100644
index 000000000..7f36d2687
--- /dev/null
+++ b/docs/source/recipes/Finetune/index.rst
@@ -0,0 +1,16 @@
+Fine-tune a pre-trained model
+=============================
+
+After pre-training on publicly available datasets, an ASR model is already capable of
+performing general speech recognition with relatively high accuracy. However, the accuracy
+can still be low on domains that differ substantially from the original training
+set. In this case, we can fine-tune the model with a small amount of additional labelled
+data to improve the performance on the new domain.
+
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Table of Contents
+
+ from_supervised/finetune_zipformer
+ adapter/finetune_adapter
diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst
index d08aa0f47..9499a3aea 100644
--- a/docs/source/recipes/TTS/ljspeech/vits.rst
+++ b/docs/source/recipes/TTS/ljspeech/vits.rst
@@ -1,11 +1,11 @@
-VITS
+VITS-LJSpeech
===============
This tutorial shows you how to train an VITS model
with the `LJSpeech `_ dataset.
.. note::
-
+
TTS related recipes require packages in ``requirements-tts.txt``.
.. note::
@@ -13,6 +13,14 @@ with the `LJSpeech `_ dataset.
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_
+Install extra dependencies
+--------------------------
+
+.. code-block:: bash
+
+ pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
+ pip install numba espnet_tts_frontend
+
Data preparation
----------------
@@ -56,7 +64,8 @@ Training
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
- --tokens data/tokens.txt
+ --tokens data/tokens.txt \
+ --model-type high \
--max-duration 500
.. note::
@@ -64,6 +73,11 @@ Training
You can adjust the hyper-parameters to control the size of the VITS model and
the training configurations. For more details, please run ``./vits/train.py --help``.
+.. warning::
+
+ If you want a model that runs faster on CPU, please use ``--model-type low``
+ or ``--model-type medium``.
+
.. note::
The training can take a long time (usually a couple of days).
@@ -95,8 +109,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
Export models
-------------
-Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
-``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
+Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
+``vits-epoch-*.onnx``.
.. code-block:: bash
@@ -120,4 +134,68 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:
- - ``_
+ - ``--model-type=high``: ``_
+ - ``--model-type=medium``: ``_
+ - ``--model-type=low``: ``_
+
+Usage in sherpa-onnx
+--------------------
+
+The following describes how to test the exported ONNX model in `sherpa-onnx`_.
+
+.. hint::
+
+ `sherpa-onnx`_ supports different programming languages, e.g., C++, C, Python,
+ Kotlin, Java, Swift, Go, C#, etc. It also supports Android and iOS.
+
+ We only describe how to use pre-built binaries from `sherpa-onnx`_ below.
+ Please refer to ``_
+ for more documentation.
+
+Install sherpa-onnx
+^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: bash
+
+ pip install sherpa-onnx
+
+To check that you have installed `sherpa-onnx`_ successfully, please run:
+
+.. code-block:: bash
+
+ which sherpa-onnx-offline-tts
+ sherpa-onnx-offline-tts --help
+
+Download lexicon files
+^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: bash
+
+ cd /tmp
+ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2
+ tar xf espeak-ng-data.tar.bz2
+
+Run sherpa-onnx
+^^^^^^^^^^^^^^^
+
+.. code-block:: bash
+
+ cd egs/ljspeech/TTS
+
+ sherpa-onnx-offline-tts \
+ --vits-model=vits/exp/vits-epoch-1000.onnx \
+ --vits-tokens=data/tokens.txt \
+ --vits-data-dir=/tmp/espeak-ng-data \
+ --num-threads=1 \
+ --output-filename=./high.wav \
+ "Ask not what your country can do for you; ask what you can do for your country."
+
+.. hint::
+
+ You can also use ``sherpa-onnx-offline-tts-play`` to play the audio
+ as it is being generated.
+
+You should get a file ``high.wav`` after running the above command.
+
+Congratulations! You have successfully trained and exported a text-to-speech
+model and run it with `sherpa-onnx`_.
diff --git a/docs/source/recipes/TTS/vctk/vits.rst b/docs/source/recipes/TTS/vctk/vits.rst
index 34024a5ea..45ae9d9d2 100644
--- a/docs/source/recipes/TTS/vctk/vits.rst
+++ b/docs/source/recipes/TTS/vctk/vits.rst
@@ -1,11 +1,11 @@
-VITS
+VITS-VCTK
===============
This tutorial shows you how to train an VITS model
with the `VCTK `_ dataset.
.. note::
-
+
TTS related recipes require packages in ``requirements-tts.txt``.
.. note::
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index 8df61f0d0..52795d452 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -17,3 +17,4 @@ We may add recipes for other tasks as well in the future.
Streaming-ASR/index
RNN-LM/index
TTS/index
+ Finetune/index
diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh
index 40ee2eb97..09dfd5fac 100755
--- a/egs/aidatatang_200zh/ASR/prepare.sh
+++ b/egs/aidatatang_200zh/ASR/prepare.sh
@@ -16,8 +16,8 @@ perturb_speed=true
#
# - $dl_dir/aidatatang_200zh
# You can find "corpus" and "transcript" inside it.
-# You can download it at
-# https://openslr.org/62/
+# You can download it at https://openslr.org/62/
+# If you download the data yourself, DON'T FORGET to extract the *.tar.gz files under corpus.
dl_dir=$PWD/download
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
index d491996b2..e29dd8ab5 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -288,8 +288,9 @@ class Aidatatang_200zhAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
- buffer_size=50000,
)
else:
logging.info("Using SimpleCutSampler.")
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
old mode 100644
new mode 100755
index e348f7b2b..5179bfa1c
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
@@ -20,7 +21,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 29 \
--avg 19
@@ -45,12 +46,13 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
+from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -85,10 +87,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt.",
)
parser.add_argument(
@@ -122,10 +124,14 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
+ # Load tokens.txt here
+ token_table = k2.SymbolTable.from_file(params.tokens)
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ # Load id of the <blk> token and the vocab size
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = token_table["<blk>"]
+ params.unk_id = token_table["<unk>"]
+ params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@@ -152,6 +158,7 @@ def main():
model.eval()
if params.jit:
+ convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md
index 176f065e5..d088072a7 100644
--- a/egs/aishell/ASR/README.md
+++ b/egs/aishell/ASR/README.md
@@ -19,8 +19,17 @@ The following table lists the differences among them.
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
-| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size 1 |
+| `pruned_transducer_stateless7` | Zipformer | Embedding | pruned RNN-T + zipformer encoder + stateless decoder with context-size set to 1 |
+| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe with context-size set to 1 |
+
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.
+
+# Whisper
+
+Recipe for fine-tuning large pre-trained models:
+
+| | Encoder | Decoder | Comment |
+|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------|
+| `whisper` | Transformer | Transformer | supports fine-tuning using DeepSpeed |
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 0b22f41a1..355d1516d 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -1,10 +1,120 @@
## Results
+### Aishell training results (Fine-tuning Pretrained Models)
+#### Whisper
+[./whisper](./whisper)
+##### Fine-tuning results on the Aishell test set for Whisper medium, large-v2 and large-v3
+
+| | test (before fine-tuning) | test (after fine-tuning) | comment |
+|------------------------|------|------|-----------------------------------------|
+| medium | 7.23 | 3.27 | --epoch 10 --avg 4, ddp |
+| large-v2 | 6.56 | 2.47 | --epoch 10 --avg 6, deepspeed zero stage1 |
+| large-v3 | 6.06 | 2.84 | --epoch 5 --avg 3, deepspeed zero stage1 |
+
+Command for training is:
+```bash
+pip install -r whisper/requirements.txt
+
+./prepare.sh --stage 30 --stop_stage 30
+
+# fine-tuning with deepspeed zero stage 1
+torchrun --nproc-per-node 8 ./whisper/train.py \
+ --max-duration 200 \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --deepspeed \
+ --deepspeed_config ./whisper/ds_config_zero1.json
+
+# fine-tuning with ddp
+torchrun --nproc-per-node 8 ./whisper/train.py \
+ --max-duration 200 \
+ --exp-dir whisper/exp_medium \
+ --base-lr 1e-5 \
+ --model-name medium
+```
+
+Command for decoding using fine-tuned models:
+```bash
+git lfs install
+git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
+ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
+
+python3 ./whisper/decode.py \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --epoch 999 --avg 1 \
+ --beam-size 10 --max-duration 50
+```
+Command for decoding using pretrained models (before fine-tuning):
+```bash
+python3 ./whisper/decode.py \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --epoch -1 --avg 1 \
+ --remove-whisper-encoder-input-length-restriction False \
+ --beam-size 10 --max-duration 50
+```
+Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
+are available at
+
+
### Aishell training result (Stateless Transducer)
+#### Zipformer (Byte-level BPE)
+
+[./zipformer](./zipformer/)
+
+It's a reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE and `vocab_size` set to 500.
+
+##### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M
+
+| | test | dev | comment |
+|------------------------|------|------|-----------------------------------------|
+| greedy search | 4.54 | 4.31 | --epoch 40 --avg 10 |
+| modified beam search | 4.37 | 4.11 | --epoch 40 --avg 10 |
+| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |
+
+```bash
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0,1"
+
+./zipformer/train_bbpe.py \
+ --world-size 2 \
+ --num-epochs 40 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --context-size 2 \
+ --enable-musan 0 \
+ --exp-dir zipformer/exp_bbpe \
+ --max-duration 1000 \
+ --base-lr 0.045 \
+ --lr-batches 7500 \
+ --lr-epochs 10 \
+ --spec-aug-time-warp-factor 20
+```
+
+Command for decoding is:
+```bash
+for m in greedy_search modified_beam_search fast_beam_search ; do
+ ./zipformer/decode_bbpe.py \
+ --epoch 40 \
+ --avg 10 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --context-size 2 \
+ --decoding-method $m
+done
+```
+Pretrained models, training logs, decoding logs, tensorboard and decoding results
+are available at
+
+
+
#### Zipformer (Non-streaming)
-[./zipformer](./zipformer)
+[./zipformer](./zipformer/)
It's reworked Zipformer with Pruned RNNT loss.
**Caution**: It uses `--context-size=1`.
@@ -19,7 +129,7 @@ It's reworked Zipformer with Pruned RNNT loss.
Command for training is:
```bash
-./prepare.sh
+./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1"
@@ -84,7 +194,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
- --max-duration 1200
+ --max-duration 1200
```
Command for decoding is:
@@ -134,7 +244,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
- --max-duration 800
+ --max-duration 800
```
Command for decoding is:
@@ -150,7 +260,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
- --encoder-unmasked-dim 192,192,256,320,256,192
+ --encoder-unmasked-dim 192,192,256,320,256,192
done
```
@@ -260,7 +370,7 @@ done
Pretrained models, training logs, decoding logs, and decoding results
are available at
-#### Pruned transducer stateless 7 (zipformer)
+#### Pruned transducer stateless 7 (Byte-level BPE)
See
@@ -703,7 +813,6 @@ python3 ./transducer_stateless/decode.py \
--max-sym-per-frame 3
```
-### Aishell training results (Transducer-stateless)
#### 2022-02-18
(Pingfeng Luo) : The tensorboard log for training is available at
And pretrained model is available at
diff --git a/egs/aishell/ASR/conformer_ctc/README.md b/egs/aishell/ASR/conformer_ctc/README.md
index 50596ee92..41637159d 100644
--- a/egs/aishell/ASR/conformer_ctc/README.md
+++ b/egs/aishell/ASR/conformer_ctc/README.md
@@ -1,4 +1,4 @@
Please visit
-
+
for how to run this recipe.
diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index 74a7b5933..2cb476e20 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -419,7 +419,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@@ -432,7 +432,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=enable_log,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py
index 20a855e7f..8a2daa93e 100755
--- a/egs/aishell/ASR/conformer_mmi/decode.py
+++ b/egs/aishell/ASR/conformer_mmi/decode.py
@@ -431,7 +431,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@@ -444,7 +444,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=enable_log,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index c7000da1c..3c48f0aa1 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -29,7 +29,14 @@ import os
from pathlib import Path
import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ LilcomChunkyWriter,
+ WhisperFbank,
+ WhisperFbankConfig,
+)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@@ -42,9 +49,14 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
-def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell(
+ num_mel_bins: int = 80,
+ perturb_speed: bool = False,
+ whisper_fbank: bool = False,
+ output_dir: str = "data/fbank",
+):
src_dir = Path("data/manifests")
- output_dir = Path("data/fbank")
+ output_dir = Path(output_dir)
num_jobs = min(15, os.cpu_count())
dataset_parts = (
@@ -68,8 +80,12 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
list(manifests.keys()),
dataset_parts,
)
-
- extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+ if whisper_fbank:
+ extractor = WhisperFbank(
+ WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+ )
+ else:
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@@ -82,7 +98,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
- logging.info(f"Doing speed perturb")
+ logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@@ -111,6 +127,18 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
+ parser.add_argument(
+ "--whisper-fbank",
+ type=str2bool,
+ default=False,
+ help="Use WhisperFbank instead of Fbank. Default: False.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="data/fbank",
+ help="Output directory. Default: data/fbank.",
+ )
return parser.parse_args()
@@ -121,5 +149,8 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_aishell(
- num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+ num_mel_bins=args.num_mel_bins,
+ perturb_speed=args.perturb_speed,
+ whisper_fbank=args.whisper_fbank,
+ output_dir=args.output_dir,
)
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index 9f73a2073..13be69534 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -360,7 +360,7 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
- log "Stage 11: Train RNN LM model"
+ log "Stage 12: Train RNN LM model"
python ../../../icefall/rnn_lm/train.py \
--start-epoch 0 \
--world-size 1 \
@@ -376,3 +376,16 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
--vocab-size 4336 \
--master-port 12345
fi
+
+# whisper large-v3 uses 128 mel bins; other models use 80 mel bins
+whisper_mel_bins=80
+output_dir=data/fbank_whisper
+if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
+ log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning"
+ if [ ! -f $output_dir/.aishell.whisper.done ]; then
+ mkdir -p $output_dir
+ ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir
+ ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir
+ touch $output_dir/.aishell.whisper.done
+ fi
+fi
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
index fb6c7c481..f41ea6776 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
@@ -390,7 +390,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -402,7 +402,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
index 2ce5cfe69..c2dc0d5f3 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
@@ -47,12 +47,12 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -106,10 +106,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
- type=Path,
- default=Path("data/lang_char"),
- help="The lang dir",
+ "--tokens",
+ type=str,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -136,10 +136,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
index 27c64efaa..3901a330c 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
@@ -526,7 +526,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -538,7 +538,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
index 723414167..2248c7a08 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
@@ -47,6 +47,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@@ -57,8 +58,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -123,10 +123,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
- type=Path,
- default=Path("data/lang_char"),
- help="The lang dir",
+ "--tokens",
+ type=str,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -153,10 +153,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
params.datatang_prob = 0
logging.info(params)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
index 696eea906..d50bccf82 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
@@ -444,7 +444,7 @@ def save_results(
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
- store_transcripts(filename=recog_path, texts=results_char)
+ store_transcripts(filename=recog_path, texts=results_char, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -452,7 +452,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
index 6027273b2..058d0ff6b 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@@ -881,9 +882,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error()
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
index 39d988cd0..4981fb71a 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
@@ -49,14 +49,14 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
+import k2
import onnx
-import sentencepiece as spm
import torch
import torch.nn as nn
from decoder2 import Decoder
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
-from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer
from icefall.checkpoint import (
@@ -65,8 +65,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import setup_logger, str2bool
+from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@@ -123,12 +122,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- help="""The lang dir
- It contains language related input files such as
- "lexicon.txt"
- """,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -404,9 +401,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py
index 9d9dd4288..2dc835f3b 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py
@@ -85,6 +85,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -878,9 +879,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
index da9000164..46f542641 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
@@ -581,7 +581,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -594,7 +594,11 @@ def save_results(
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
index 3858bafd7..811269989 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -871,9 +872,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
index 99110d6b6..61b929091 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -250,7 +250,7 @@ def get_parser():
parser.add_argument(
"--context-size",
type=int,
- default=1,
+ default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
@@ -492,7 +492,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -500,7 +500,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 0fba3b58f..6653d9d9c 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -882,9 +883,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
index 2e1044658..f3b0f1e11 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -881,9 +882,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 6abe6c084..aacbd153d 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -275,6 +275,8 @@ class AishellAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
index 824ca2a92..05e52f560 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
@@ -278,7 +278,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -289,7 +289,13 @@ def save_results(
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
- wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
+ wer = write_error_stats(
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
+ )
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index d23f4f883..d958a6338 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -327,7 +327,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
@@ -338,7 +338,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index 01de5d772..bfd0ecb0c 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -23,7 +23,7 @@
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 20 \
--avg 10
@@ -47,6 +47,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
import torch.nn as nn
from conformer import Conformer
@@ -56,8 +57,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@@ -92,10 +92,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -192,10 +192,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
index cd8dd821c..ed453afd2 100644
--- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
@@ -226,6 +226,8 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
index d164b6890..57f7a8239 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
@@ -372,7 +372,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -384,7 +384,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
index c1081c32b..4f2c71d18 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
@@ -46,6 +46,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
import torch.nn as nn
from conformer import Conformer
@@ -56,7 +57,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@@ -99,10 +100,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
- type=Path,
- default=Path("data/lang_char"),
- help="The lang dir",
+ "--tokens",
+ type=str,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -190,10 +191,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py
index 0a7d87fe8..56f3724eb 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py
@@ -376,7 +376,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -388,7 +388,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
index 3e14ad69c..487748947 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/export.py
@@ -46,6 +46,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
import torch.nn as nn
from conformer import Conformer
@@ -55,8 +56,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@@ -99,10 +99,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
- type=Path,
- default=Path("data/lang_char"),
- help="The lang dir",
+ "--tokens",
+ type=str,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -190,10 +190,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell/ASR/whisper/asr_datamodule.py b/egs/aishell/ASR/whisper/asr_datamodule.py
new file mode 120000
index 000000000..fa1b8cca3
--- /dev/null
+++ b/egs/aishell/ASR/whisper/asr_datamodule.py
@@ -0,0 +1 @@
+../tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py
new file mode 100755
index 000000000..5350cb2b0
--- /dev/null
+++ b/egs/aishell/ASR/whisper/decode.py
@@ -0,0 +1,507 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
+# Fangjun Kuang,
+# Wei Kang)
+# 2024 Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+# Command for decoding using fine-tuned models:
+git lfs install
+git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
+ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
+
+python3 ./whisper/decode.py \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --epoch 999 --avg 1 \
+ --manifest-dir data/fbank_whisper \
+ --beam-size 10 --max-duration 50
+
+# Command for decoding using pretrained models (before fine-tuning):
+
+python3 ./whisper/decode.py \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --epoch -1 --avg 1 \
+ --manifest-dir data/fbank_whisper \
+ --remove-whisper-encoder-input-length-restriction False \
+ --beam-size 10 --max-duration 50
+
+"""
+
+import argparse
+import logging
+import re
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import k2
+import torch
+import torch.nn as nn
+import whisper
+from asr_datamodule import AishellAsrDataModule
+from tn.chinese.normalizer import Normalizer
+from whisper.normalizers import BasicTextNormalizer
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from zhconv import convert
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
+from icefall.env import get_env_info
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def average_checkpoints(
+ filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
+ """Average a list of checkpoints.
+ The function is mainly used for averaging deepspeed-converted checkpoints, which contain only the model state_dict.
+
+ Args:
+ filenames:
+ Filenames of the checkpoints to be averaged. We assume all
+ checkpoints are saved by :func:`save_checkpoint`.
+ device:
+ Move checkpoints to this device before averaging.
+ Returns:
+ Return a dict (i.e., state_dict) which is the average of all
+ model state dicts contained in the checkpoints.
+ """
+ n = len(filenames)
+
+ if "model" in torch.load(filenames[0], map_location=device):
+ avg = torch.load(filenames[0], map_location=device)["model"]
+ else:
+ avg = torch.load(filenames[0], map_location=device)
+
+ # Identify shared parameters. Two parameters are said to be shared
+ # if they have the same data_ptr
+ uniqued: Dict[int, str] = dict()
+
+ for k, v in avg.items():
+ v_data_ptr = v.data_ptr()
+ if v_data_ptr in uniqued:
+ continue
+ uniqued[v_data_ptr] = k
+
+ uniqued_names = list(uniqued.values())
+
+ for i in range(1, n):
+ if "model" in torch.load(filenames[i], map_location=device):
+ state_dict = torch.load(filenames[i], map_location=device)["model"]
+ else:
+ state_dict = torch.load(filenames[i], map_location=device)
+ for k in uniqued_names:
+ avg[k] += state_dict[k]
+
+ for k in uniqued_names:
+ if avg[k].is_floating_point():
+ avg[k] /= n
+ else:
+ avg[k] //= n
+
+ return avg
+
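
A short, hedged sketch of how the helper above can be used for deepspeed-converted epoch checkpoints; the file names follow the epoch-N.pt convention used later in main(), and the import assumes the script is run from egs/aishell/ASR/whisper.

```python
# Hedged sketch: averaging deepspeed-converted checkpoints with the helper
# defined above. Assumes the working directory is egs/aishell/ASR/whisper.
import whisper
from decode import average_checkpoints

filenames = [f"whisper/exp_large_v2/epoch-{e}.pt" for e in range(5, 11)]
avg_state_dict = average_checkpoints(filenames)  # plain state_dict average

model = whisper.load_model("large-v2", "cpu")
model.load_state_dict(avg_state_dict)
```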
+
+def remove_punctuation(text: Union[str, List[str]]):
+ """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
+
+ Args:
+ text: It can be a string or a list of strings.
+ Returns:
+ Return a string or a list of strings without any punctuation.
+ """
+ punctuation = "!,.;:?、!,。;:?《》 "
+ if isinstance(text, str):
+ text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
+ return text
+ elif isinstance(text, list):
+ result_text = []
+ for t in text:
+ t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
+ result_text.append(t)
+ return result_text
+ else:
+ raise Exception(f"Not support type {type(text)}")
+
+
+def to_simple(text: Union[str, List[str]]):
+ """Convert traditional Chinese to simplified Chinese.
+ Args:
+ text: It can be a string or a list of strings.
+ Returns:
+ Return a string or a list of strings converted to simplified Chinese.
+ """
+ if isinstance(text, str):
+ text = convert(text, "zh-cn")
+ return text
+ elif isinstance(text, list):
+ result_text = []
+ for t in text:
+ t = convert(t, "zh-cn")
+ result_text.append(t)
+ return result_text
+ else:
+ raise Exception(f"Not support type{type(text)}")
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=-1,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=1,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="beam-search",
+ help="""Decoding method.
+ Supported values are:
+ - beam-search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=1,
+ help="beam size for beam search decoding",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="whisper/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="large-v2",
+ choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
+ help="""The model name to use.
+ """,
+ )
+
+ parser.add_argument(
+ "--remove-whisper-encoder-input-length-restriction",
+ type=str2bool,
+ default=True,
+ help="replace whisper encoder forward method to remove input length restriction",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ "env_info": get_env_info(),
+ }
+ )
+ return params
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ batch: dict,
+) -> Dict[str, List[List[int]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: "beam-search"
+ - value: A list of lists. Each sublist is a list of token IDs.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
+ Returns:
+ Return a dict, whose key may be "beam-search".
+ """
+ dtype = torch.float16
+ device = torch.device("cuda")
+
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ feature = feature.to(device, dtype=dtype).transpose(1, 2)
+ if not params.remove_whisper_encoder_input_length_restriction:
+ T = 3000
+ if feature.shape[2] < T:
+ feature = torch.cat(
+ [
+ feature,
+ torch.zeros(
+ feature.shape[0], feature.shape[1], T - feature.shape[2]
+ ).to(device, dtype=dtype),
+ ],
+ 2,
+ )
+
+ supervisions = batch["supervisions"]
+ feature_len = supervisions["num_frames"]
+ feature_len = feature_len.to(device, dtype=dtype)
+ results = model.decode(feature, params.decoding_options)
+ hyps = [result.text for result in results]
+
+ hyps = remove_punctuation(hyps)
+ hyps = to_simple(hyps)
+ hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
+
+ return {"beam-search": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ The dataloader.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ Returns:
+ Return a dict, whose key may be "beam-search".
+ """
+ results = []
+
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ batch=batch,
+ )
+
+ for lm_scale, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[lm_scale].extend(this_batch)
+
+ num_cuts += len(batch["supervisions"]["text"])
+
+ if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+
+ enable_log = True
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
+ if enable_log:
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ # we compute CER for aishell dataset.
+ results_char = []
+ for res in results:
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=enable_log,
+ compute_CER=True,
+ )
+ test_set_wers[key] = wer
+
+ if enable_log:
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tCER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+ setup_logger(
+ f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+ )
+
+ options = whisper.DecodingOptions(
+ task="transcribe",
+ language="zh",
+ without_timestamps=True,
+ beam_size=params.beam_size,
+ )
+ params.decoding_options = options
+ params.cleaner = BasicTextNormalizer()
+ params.normalizer = Normalizer()
+
+ logging.info("Decoding started")
+ logging.info(params)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+
+ logging.info(f"device: {device}")
+
+ if params.remove_whisper_encoder_input_length_restriction:
+ replace_whisper_encoder_forward()
+ model = whisper.load_model(params.model_name, "cpu")
+ if params.epoch > 0:
+ if params.avg > 1:
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ checkpoint = torch.load(
+ f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+ )
+ if "model" not in checkpoint:
+ # deepspeed converted checkpoint only contains model state_dict
+ filenames = [
+ f"{params.exp_dir}/epoch-{epoch}.pt"
+ for epoch in range(start, params.epoch + 1)
+ ]
+ model.load_state_dict(average_checkpoints(filenames))
+ else:
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ # save checkpoints
+ filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+ torch.save(model.state_dict(), filename)
+ else:
+ checkpoint = torch.load(
+ f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+ )
+ if "model" not in checkpoint:
+ model.load_state_dict(checkpoint, strict=True)
+ else:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ model.to(device)
+ model.eval()
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ aishell = AishellAsrDataModule(args)
+ valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
+ test_dl = aishell.test_dataloaders(aishell.test_cuts())
+ test_sets = ["valid", "test"]
+ test_dls = [valid_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ )
+
+ save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+ logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
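
Before scoring, decode.py strips punctuation, converts traditional to simplified Chinese, and applies WeTextProcessing normalization. The following hedged sketch replays that chain on a made-up hypothesis; all three dependencies appear in whisper/requirements.txt.

```python
# Hedged sketch of the hypothesis post-processing chain in decode.py.
# The sample sentence is made up; zhconv and WeTextProcessing (tn) are the
# packages listed in whisper/requirements.txt.
import re

from tn.chinese.normalizer import Normalizer
from zhconv import convert

punctuation = "!,.;:?、!,。;:?《》 "
normalizer = Normalizer()

hyp = "這是一個測試句子, 妳好。"
hyp = re.sub(r"[{}]+".format(punctuation), "", hyp).strip()  # drop punctuation
hyp = convert(hyp, "zh-cn")                                  # traditional -> simplified
hyp = normalizer.normalize(hyp)                              # text normalization
print(hyp)
```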
diff --git a/egs/aishell/ASR/whisper/ds_config_zero1.json b/egs/aishell/ASR/whisper/ds_config_zero1.json
new file mode 100644
index 000000000..bf8cc0452
--- /dev/null
+++ b/egs/aishell/ASR/whisper/ds_config_zero1.json
@@ -0,0 +1,38 @@
+{
+ "fp16": {
+ "enabled": true,
+ "loss_scale": 0,
+ "loss_scale_window": 100,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 0.01
+ },
+ "zero_optimization": {
+ "stage": 1,
+ "allgather_partitions": true,
+ "allgather_bucket_size": 2e8,
+ "overlap_comm": true,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 2e8,
+ "contiguous_gradients": true
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1e-5
+ }
+ },
+ "scheduler": {
+ "type": "WarmupLR",
+ "params": {
+ "warmup_min_lr": 0,
+ "warmup_max_lr": 1e-5,
+ "warmup_num_steps": 100
+ }
+ },
+ "gradient_accumulation_steps": 1,
+ "gradient_clipping": 5,
+ "steps_per_print": 50,
+ "train_micro_batch_size_per_gpu": 1,
+ "wall_clock_breakdown": false
+}
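
This ZeRO stage-1 config is consumed through deepspeed.initialize() in train.py, which receives the arguments added by deepspeed.add_config_arguments(). The sketch below fakes those arguments with a tiny namespace just to show the call shape; it would normally run under torchrun or the deepspeed launcher, and the stand-in model is not the whisper model.

```python
# Hedged sketch of consuming ds_config_zero1.json. Normally this runs under
# torchrun (see the usage string in whisper/train.py); the namespace and the
# stand-in model are illustrative only.
import argparse

import deepspeed
import torch.nn as nn

model = nn.Linear(80, 10)  # stand-in for the whisper model

args = argparse.Namespace(
    deepspeed=True,
    deepspeed_config="whisper/ds_config_zero1.json",
    local_rank=-1,
)

model_engine, optimizer, _, scheduler = deepspeed.initialize(
    args=args, model=model, model_parameters=model.parameters()
)
# Training then calls model_engine.backward(loss) and model_engine.step(),
# as done in train_one_epoch().
```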
diff --git a/egs/aishell/ASR/whisper/label_smoothing.py b/egs/aishell/ASR/whisper/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/aishell/ASR/whisper/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/whisper/optim.py b/egs/aishell/ASR/whisper/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/aishell/ASR/whisper/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/whisper/requirements.txt b/egs/aishell/ASR/whisper/requirements.txt
new file mode 100755
index 000000000..0708f2344
--- /dev/null
+++ b/egs/aishell/ASR/whisper/requirements.txt
@@ -0,0 +1,10 @@
+k2
+kaldialign
+git+https://github.com/lhotse-speech/lhotse
+sentencepiece
+tensorboard
+librosa
+git+https://github.com/yuekaizhang/whisper.git
+zhconv
+WeTextProcessing
+deepspeed
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
new file mode 100755
index 000000000..d77f8c270
--- /dev/null
+++ b/egs/aishell/ASR/whisper/train.py
@@ -0,0 +1,927 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+# 2024 Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+#fine-tuning with deepspeed zero stage 1
+torchrun --nproc_per_node 8 ./whisper/train.py \
+ --max-duration 200 \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --manifest-dir data/fbank_whisper \
+ --deepspeed \
+ --deepspeed_config ./whisper/ds_config_zero1.json
+
+# fine-tuning with ddp
+torchrun --nproc_per_node 8 ./whisper/train.py \
+ --max-duration 200 \
+ --exp-dir whisper/exp_medium \
+ --manifest-dir data/fbank_whisper \
+ --base-lr 1e-5 \
+ --model-name medium
+"""
+
+
+import argparse
+import copy
+import logging
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import deepspeed
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import whisper
+from asr_datamodule import AishellAsrDataModule
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from label_smoothing import LabelSmoothingLoss
+from lhotse import CutSet, load_manifest
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.functional import pad as pad_tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import update_averaged_model
+from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ filter_uneven_sized_batch,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=10,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="whisper/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="large-v2",
+ choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
+ help="""The model name to use.
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=1e-5, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+ parser = deepspeed.add_config_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - frame_shift_ms: The frame shift in milliseconds.
+ - allowed_excess_duration_ratio: The allowed excess duration ratio.
+ - best_train_loss: The best training loss so far.
+ - best_valid_loss: The best validation loss so far.
+ - best_train_epoch: The epoch where the best training loss is achieved.
+ - best_valid_epoch: The epoch where the best validation loss is achieved.
+ - batch_idx_train: The batch index of the current batch.
+ - log_interval: Log training stats every `log_interval` batches.
+ - reset_interval: Reset the stats every `reset_interval` batches.
+ - valid_interval: Run validation every `valid_interval` batches.
+ - env_info: The environment information.
+ """
+ params = AttributeDict(
+ {
+ "frame_shift_ms": 10.0,
+ "subsampling_factor": 2,
+ "allowed_excess_duration_ratio": 0.1,
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 5000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ tokenizer: whisper.tokenizer.Tokenizer,
+ model: Union[nn.Module, DDP],
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute the loss for the given batch.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ tokenizer:
+ The tokenizer used to encode the text.
+ model:
+ The model for training.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ Whether it is training.
+ Returns:
+ Return a tuple of two elements: the loss tensor and a MetricsTracker
+ containing frame and loss statistics.
+ """
+ # For the uneven-sized batch, the total duration after padding would possibly
+ # cause OOM. Hence, for each batch, which is sorted in descending order of length,
+ # we simply drop the last few shortest samples, so that the retained total frames
+ # (after padding) would not exceed `allowed_max_frames`:
+ # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+ # where `max_frames = max_duration * 1000 // frame_shift_ms`.
+ # We set allowed_excess_duration_ratio=0.1.
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+
+ def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
+ padding_size = max(tensor.shape[0] for tensor in tensors)
+ dims = len(tensors[0].shape)
+ padded_tensors = []
+ for tensor in tensors:
+ padding = [0] * 2 * dims
+ padding[-1] = padding_size - tensor.shape[0]
+ padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
+ return torch.stack(padded_tensors, dim=0)
+
+ max_frames = params.max_duration * 1000 // params.frame_shift_ms
+ allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
+ batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+
+ assert feature.ndim == 3
+ feature = feature.to(device)
+ feature = feature.transpose(1, 2) # (N, C, T)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+
+ texts = batch["supervisions"]["text"]
+ # remove spaces in texts
+ texts = [text.replace(" ", "") for text in texts]
+
+ text_tokens_list = [
+ list(tokenizer.sot_sequence_including_notimestamps)
+ + tokenizer.encode(text)
+ + [tokenizer.eot]
+ for text in texts
+ ]
+ # convert it to torch tensor
+ text_tokens_list = [
+ torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
+ ]
+
+ # 50256 is used as the padding token index for all whisper models
+ prev_outputs_tokens = _batch_tensors(
+ [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
+ )
+ target_tokens = _batch_tensors(
+ [tokens[1:] for tokens in text_tokens_list], pad_value=50256
+ )
+ target_lengths = torch.LongTensor(
+ [tokens.shape[0] - 1 for tokens in text_tokens_list]
+ )
+
+ decoder_criterion = LabelSmoothingLoss(
+ ignore_index=50256, label_smoothing=0.1, reduction="sum"
+ )
+
+ # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
+ ignore_prefix_size = 3
+ with torch.set_grad_enabled(is_training):
+ encoder_out = model.encoder(feature)
+ text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
+ text_logits = text_logits[:, ignore_prefix_size:, :]
+ target_tokens = target_tokens[:, ignore_prefix_size:]
+ loss = decoder_criterion(text_logits, target_tokens.to(device))
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ tokenizer: whisper.tokenizer.Tokenizer,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ tokenizer: whisper.tokenizer.Tokenizer,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ if params.deepspeed:
+ # deepspeed's backward() differs from a plain torch backward pass:
+ # it handles fp16 loss scaling and gradient accumulation internally,
+ # and model.step() then performs the optimizer and LR-scheduler updates.
+ model.backward(loss)
+ model.step()
+ else:
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ and not params.deepspeed
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+ if batch_idx % params.log_interval == 0:
+ try:
+ cur_lr = scheduler.get_last_lr()[0]
+ except: # noqa
+ cur_lr = 0.0
+ cur_grad_scale = (
+ scaler._scale.item()
+ if (params.use_fp16 and not params.deepspeed)
+ else 1.0
+ )
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (
+ f"grad_scale: {scaler._scale.item()}"
+ if (params.use_fp16 and not params.deepspeed)
+ else ""
+ )
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info(params)
+
+ logging.info("About to create model")
+
+ replace_whisper_encoder_forward()
+ model = whisper.load_model(params.model_name, "cpu")
+ del model.alignment_heads
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ tokenizer = whisper.tokenizer.get_tokenizer(
+ model.is_multilingual,
+ num_languages=model.num_languages,
+ language="zh",
+ task="transcribe",
+ )
+
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ else:
+ device = torch.device("cpu")
+ logging.info(f"Device: {device}")
+ model.to(device)
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if world_size > 1:
+ if params.deepspeed:
+ logging.info("Using DeepSpeed")
+ model, optimizer, _, scheduler = deepspeed.initialize(
+ args=params, model=model, model_parameters=model.parameters()
+ )
+ else:
+ logging.info("Using DDP")
+ setup_dist(use_ddp_launch=True)
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ aishell = AishellAsrDataModule(args)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = aishell.train_dataloaders(aishell.train_cuts())
+ valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ logging.info(f"start training from epoch {params.start_epoch}")
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ if not params.deepspeed:
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ if params.deepspeed:
+ model.save_checkpoint(
+ save_dir=params.exp_dir,
+ tag=f"epoch-{params.cur_epoch}",
+ client_state={},
+ )
+ if rank == 0:
+ convert_zero_checkpoint_to_fp32_state_dict(
+ params.exp_dir,
+ f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
+ tag=f"epoch-{params.cur_epoch}",
+ )
+ else:
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1 and not params.deepspeed:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ run(rank=rank, world_size=world_size, args=args)
+
+
+if __name__ == "__main__":
+ main()
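
compute_loss() builds teacher-forcing inputs and shifted targets from the whisper tokenizer, padding with 50256 and ignoring the three leading prompt tokens. Below is a hedged, standalone sketch of that construction on a made-up transcript; note that train.py derives the multilingual/num_languages settings from the loaded model rather than hard-coding them as done here.

```python
# Hedged sketch of the target construction in compute_loss(); the transcript
# is made up and the tokenizer flags are assumptions (train.py reads them
# from the loaded model instead).
import torch
import whisper

tokenizer = whisper.tokenizer.get_tokenizer(
    multilingual=True, language="zh", task="transcribe"
)

text = "今天天气很好"
tokens = torch.LongTensor(
    list(tokenizer.sot_sequence_including_notimestamps)
    + tokenizer.encode(text)
    + [tokenizer.eot]
)

prev_outputs = tokens[:-1]  # decoder input (teacher forcing)
targets = tokens[1:]        # next-token prediction targets
# The first 3 target positions are the special prompt tokens and are dropped
# before the label-smoothed cross entropy, as in compute_loss().
print(prev_outputs.tolist())
print(targets.tolist())
```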
diff --git a/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
new file mode 100644
index 000000000..5bfbdce3b
--- /dev/null
+++ b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn.functional as F
+import whisper
+
+
+def forward(self, x: torch.Tensor):
+ """
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+ the mel spectrogram of the audio
+ """
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+ x = x.permute(0, 2, 1)
+
+ x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)
+
+ for block in self.blocks:
+ x = block(x)
+
+ x = self.ln_post(x)
+ return x
+
+
+def replace_whisper_encoder_forward():
+ """
+ This function monkey patches the forward method of the whisper encoder.
+ It should be called before the model is loaded; it allows whisper to process audio of any length below 30 s.
+ """
+ whisper.model.AudioEncoder.forward = forward
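
A hedged usage sketch for the patch above, run from egs/aishell/ASR/whisper: patch first, then load a small (purely illustrative) model and feed a mel segment shorter than the usual 3000-frame / 30 s window.

```python
# Hedged sketch of using the monkey patch. Run from egs/aishell/ASR/whisper;
# "tiny" is chosen only to keep the example light.
import torch
import whisper
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

replace_whisper_encoder_forward()          # patch before loading, as advised above
model = whisper.load_model("tiny", "cpu")

mel = torch.randn(1, model.dims.n_mels, 500)  # ~5 s of frames instead of 3000
with torch.no_grad():
    encoder_out = model.encoder(mel)
print(encoder_out.shape)  # frames are halved by the stride-2 conv, so 250 here
```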
diff --git a/egs/aishell/ASR/zipformer/decode.py b/egs/aishell/ASR/zipformer/decode.py
index 1968904ae..538189e52 100755
--- a/egs/aishell/ASR/zipformer/decode.py
+++ b/egs/aishell/ASR/zipformer/decode.py
@@ -560,7 +560,7 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -570,7 +570,11 @@ def save_results(
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/zipformer/decode_bbpe.py b/egs/aishell/ASR/zipformer/decode_bbpe.py
new file mode 100755
index 000000000..1ec10b059
--- /dev/null
+++ b/egs/aishell/ASR/zipformer/decode_bbpe.py
@@ -0,0 +1,840 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Mingshuang Luo,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode_bbpe.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --lang-dir data/lang_bbpe_500 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) modified beam search
+./zipformer/decode_bbpe.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --lang-dir data/lang_bbpe_500 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(3) fast beam search (trivial_graph)
+./zipformer/decode_bbpe.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --lang-dir data/lang_bbpe_500 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(4) fast beam search (LG)
+./zipformer/decode_bbpe.py \
+ --epoch 30 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --lang-dir data/lang_bbpe_500 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest oracle WER)
+./zipformer/decode_bbpe.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --lang-dir data/lang_bbpe_500 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from train import add_model_arguments, get_model, get_params
+
+from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer_bbpe/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bbpe_500/bbpe.model",
+ help="Path to the byte BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bbpe_500/",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_LG
+ - fast_beam_search_nbest_oracle
+ If you use fast_beam_search_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--ilme-scale",
+ type=float,
+ default=0.2,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for the internal language model estimation.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--blank-penalty",
+ type=float,
+ default=0.0,
+ help="""
+ The penalty applied on blank symbol during decoding.
+ Note: It is a positive value that would be applied to logits like
+ this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+ [batch_size, vocab] and blank id is 0).
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ lexicon: Lexicon,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ x, x_lens = model.encoder_embed(feature, feature_lens)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ blank_penalty=params.blank_penalty,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "fast_beam_search_LG":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ blank_penalty=params.blank_penalty,
+ ilme_scale=params.ilme_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([lexicon.word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ ref_texts = []
+ for tx in supervisions["text"]:
+ ref_texts.append(byte_encode(tokenize_by_CJK_char(tx)))
+
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(ref_texts),
+ nbest_scale=params.nbest_scale,
+ blank_penalty=params.blank_penalty,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ blank_penalty=params.blank_penalty,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ blank_penalty=params.blank_penalty,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ blank_penalty=params.blank_penalty,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ blank_penalty=params.blank_penalty,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(smart_byte_decode(sp.decode(hyp)).split())
+
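+ # The key encodes the decoding method and its hyper-parameters so that
+ # results obtained with different settings are stored separately.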
+ key = f"blank_penalty_{params.blank_penalty}"
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search_" + key: hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key += f"_beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ilme_scale_{params.ilme_scale}"
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}_" + key: hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ lexicon:
+ The lexicon; its word table is used to map word IDs to words
+ when decoding with fast_beam_search_LG.
+ sp:
+ SentencePiece model.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or an LG graph.
+ Used only when --decoding_method is fast_beam_search,
+ fast_beam_search_LG, or fast_beam_search_nbest_oracle.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = "".join(ref_text.split())
+
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+
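+ # Scoring for Chinese is done at the character level: join the reference
+ # and hypothesis words and split them back into individual characters.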
+ results_char = []
+ for res in results:
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results_char, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "modified_beam_search",
+ "fast_beam_search",
+ "fast_beam_search_LG",
+ "fast_beam_search_nbest_oracle",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"_ilme_scale_{params.ilme_scale}"
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+ params.suffix += f"-blank-penalty-{params.blank_penalty}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bbpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ lexicon = Lexicon(params.lang_dir)
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ if "LG" in params.decoding_method:
+ lexicon = Lexicon(params.lang_dir)
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ aishell = AishellAsrDataModule(args)
+
+ def remove_short_utt(c: Cut):
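+ # T is the number of frames left after the encoder frontend, which
+ # subsamples roughly by a factor of 4: ((num_frames - 7) // 2 + 1) // 2.
+ # Cuts that would produce no frames are excluded from decoding.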
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ if T <= 0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
+ )
+ return T > 0
+
+ dev_cuts = aishell.valid_cuts()
+ dev_cuts = dev_cuts.filter(remove_short_utt)
+ dev_dl = aishell.valid_dataloaders(dev_cuts)
+
+ test_cuts = aishell.test_cuts()
+ test_cuts = test_cuts.filter(remove_short_utt)
+ test_dl = aishell.test_dataloaders(test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dls = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py b/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py
new file mode 100755
index 000000000..cd16284f7
--- /dev/null
+++ b/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer_bbpe/exp \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+Usage of this script:
+
+./zipformer/jit_pretrained.py \
+ --nn-model-filename ./zipformer_bbpe/exp/cpu_jit.pt \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall import smart_byte_decode
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--nn-model-filename",
+ type=str,
+ required=True,
+ help="Path to the torchscript model cpu_jit.pt",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ required=True,
+ help="""Path to the bbpe.model.""",
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0].contiguous())
+ return ans
+
+
+def greedy_search(
+ model: torch.jit.ScriptModule,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+) -> List[List[int]]:
+ """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+ Args:
+ model:
+ The transducer model.
+ encoder_out:
+ A 3-D tensor of shape (N, T, C)
+ encoder_out_lens:
+ A 1-D tensor of shape (N,).
+ Returns:
+ Return the decoded results for each utterance.
+ """
+ assert encoder_out.ndim == 3
+ assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+ packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+ input=encoder_out,
+ lengths=encoder_out_lens.cpu(),
+ batch_first=True,
+ enforce_sorted=False,
+ )
+
+ device = encoder_out.device
+ blank_id = model.decoder.blank_id
+
+ batch_size_list = packed_encoder_out.batch_sizes.tolist()
+ N = encoder_out.size(0)
+
+ assert torch.all(encoder_out_lens > 0), encoder_out_lens
+ assert N == batch_size_list[0], (N, batch_size_list)
+
+ context_size = model.decoder.context_size
+ hyps = [[blank_id] * context_size for _ in range(N)]
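+ # Each hypothesis starts with context_size blanks, which serve as the
+ # initial context for the decoder (prediction network).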
+
+ decoder_input = torch.tensor(
+ hyps,
+ device=device,
+ dtype=torch.int64,
+ ) # (N, context_size)
+
+ decoder_out = model.decoder(
+ decoder_input,
+ need_pad=torch.tensor([False]),
+ ).squeeze(1)
+
+ offset = 0
+ for batch_size in batch_size_list:
+ start = offset
+ end = offset + batch_size
+ current_encoder_out = packed_encoder_out.data[start:end]
+ # current_encoder_out's shape: (batch_size, encoder_out_dim)
+ offset = end
+
+ decoder_out = decoder_out[:batch_size]
+
+ logits = model.joiner(
+ current_encoder_out,
+ decoder_out,
+ )
+ # logits' shape: (batch_size, vocab_size)
+
+ assert logits.ndim == 2, logits.shape
+ y = logits.argmax(dim=1).tolist()
+ emitted = False
+ for i, v in enumerate(y):
+ if v != blank_id:
+ hyps[i].append(v)
+ emitted = True
+ if emitted:
+ # update decoder output
+ decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
+ decoder_input = torch.tensor(
+ decoder_input,
+ device=device,
+ dtype=torch.int64,
+ )
+ decoder_out = model.decoder(
+ decoder_input,
+ need_pad=torch.tensor([False]),
+ )
+ decoder_out = decoder_out.squeeze(1)
+
+ sorted_ans = [h[context_size:] for h in hyps]
+ ans = []
+ unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+ for i in range(N):
+ ans.append(sorted_ans[unsorted_indices[i]])
+
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ model = torch.jit.load(args.nn_model_filename)
+
+ model.eval()
+
+ model.to(device)
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(args.bpe_model)
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
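+ # A non-positive high_freq is interpreted relative to the Nyquist
+ # frequency, i.e. the upper mel cutoff is 8000 - 400 = 7600 Hz for 16 kHz input.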
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {args.sound_files}")
+ waves = read_sound_files(
+ filenames=args.sound_files,
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(
+ features,
+ batch_first=True,
+ padding_value=math.log(1e-10),
+ )
+
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ features=features,
+ feature_lengths=feature_lengths,
+ )
+
+ hyps = greedy_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+
+ s = "\n"
+ for filename, hyp in zip(args.sound_files, hyps):
+ words = smart_byte_decode(sp.decode(hyp))
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/zipformer/pretrained_bbpe.py b/egs/aishell/ASR/zipformer/pretrained_bbpe.py
new file mode 100755
index 000000000..387bef98a
--- /dev/null
+++ b/egs/aishell/ASR/zipformer/pretrained_bbpe.py
@@ -0,0 +1,403 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+Note: This is an example for the librispeech dataset; if you are using a
+different dataset, you should change the argument values according to your dataset.
+
+- For non-streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp_bbpe \
+ --tokens ./data/lang_bbpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9
+
+- For streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp_bbpe \
+ --causal 1 \
+ --tokens ./data/lang_bbpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9
+
+Usage of this script:
+
+- For non-streaming model:
+
+(1) greedy search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method modified_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method fast_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+- For streaming model:
+
+(1) greedy search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method modified_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method fast_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+
+You can also use `./zipformer/exp_bbpe/epoch-xx.pt`.
+
+Note: ./zipformer/exp_bbpe/pretrained.pt is generated by ./zipformer/export_bbpe.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+from icefall import smart_byte_decode
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ required=True,
+ help="""Path to the bbpe.model.""",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame. Used only when
+ --method is greedy_search.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0].contiguous())
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+
+ params.update(vars(args))
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bbpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+
+ logging.info("Creating model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ # model forward
+ encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
+
+ num_waves = encoder_out.size(0)
+ hyps = []
+ msg = f"Using {params.method}"
+ logging.info(msg)
+
+ if params.method == "fast_beam_search":
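+ # The trivial graph has a single state that accepts every token, so
+ # fast_beam_search is not constrained by any language model here.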
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ else:
+ for i in range(num_waves):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(f"Unsupported method: {params.method}")
+
+ hyps.append(smart_byte_decode(sp.decode(hyp)).split())
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py
index d381649e4..a25979226 100755
--- a/egs/aishell/ASR/zipformer/train.py
+++ b/egs/aishell/ASR/zipformer/train.py
@@ -86,6 +86,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -985,9 +986,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py
new file mode 100755
index 000000000..0713c5787
--- /dev/null
+++ b/egs/aishell/ASR/zipformer/train_bbpe.py
@@ -0,0 +1,941 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./zipformer/train_bbpe.py \
+ --world-size 8 \
+ --num-epochs 12 \
+ --start-epoch 1 \
+ --exp-dir zipformer/exp_bbpe \
+ --max-duration 350
+
+# For mixed precision training:
+
+./zipformer/train_bbpe.py \
+ --world-size 8 \
+ --num-epochs 12 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp_bbpe \
+ --max-duration 750
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from train import (
+ LRSchedulerType,
+ add_model_arguments,
+ get_adjusted_batch_count,
+ get_model,
+ get_params,
+ load_checkpoint_if_available,
+ save_checkpoint,
+ set_batch_count,
+)
+
+from icefall import byte_encode, diagnostics
+from icefall.checkpoint import remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.err import raise_grad_scale_is_too_small_error
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+ tokenize_by_CJK_char,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer_bbpe/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bbpe_500/bbpe.model",
+ help="Path to the Byte BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="""The prune range for rnnt loss, it means how many symbols(context)
+ we are using to compute the loss""",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="""The scale to smooth the loss with lm
+ (output of prediction network) part.""",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="""The scale to smooth the loss with am (output of encoder network) part.""",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="""To get pruning ranges, we will calculate a simple version
+ loss(joiner is just addition), this simple loss also uses for
+ training (as a regularization item). We will scale the simple loss
+ with this parameter before adding to the final loss.""",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute the pruned RNN-T (transducer) loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, _ = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
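+ # cur_batch_idx is non-zero only when resuming from a checkpoint saved in
+ # the middle of an epoch; earlier batches are skipped in the loop below.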
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bbpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ aishell = AishellAsrDataModule(args)
+
+ train_cuts = aishell.train_cuts()
+ valid_cuts = aishell.valid_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 15 seconds
+ #
+ # Caution: There is a reason to select 15.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 15.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ def tokenize_and_encode_text(c: Cut):
+ # Text normalize for each sample
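+ # tokenize_by_CJK_char separates CJK characters with spaces, and
+ # byte_encode maps the text to the byte-level representation expected
+ # by the byte BPE (bbpe) model.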
+ text = c.supervisions[0].text
+ text = byte_encode(tokenize_by_CJK_char(text))
+ c.supervisions[0].text = text
+ return c
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ train_cuts = train_cuts.map(tokenize_and_encode_text)
+
+ valid_cuts = valid_cuts.map(tokenize_and_encode_text)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = aishell.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_dl = aishell.valid_dataloaders(valid_cuts)
+
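+ # The OOM sanity check below is intentionally disabled (the `False and`
+ # short-circuits the condition); drop `False and` to re-enable it.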
+ if False and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The sentence piece model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell2/ASR/RESULTS.md b/egs/aishell2/ASR/RESULTS.md
index 32ad74b50..0b7ae9299 100644
--- a/egs/aishell2/ASR/RESULTS.md
+++ b/egs/aishell2/ASR/RESULTS.md
@@ -1,6 +1,6 @@
## Results
-### Aishell2 char-based training results
+### Aishell2 char-based training results
#### Pruned transducer stateless 5
diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
index 1fb1621ff..557f22b0c 100755
--- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py
+++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py
@@ -29,7 +29,14 @@ import os
from pathlib import Path
import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ LilcomChunkyWriter,
+ WhisperFbank,
+ WhisperFbankConfig,
+)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
-def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell2(
+ num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
- num_jobs = min(15, os.cpu_count())
+ num_jobs = min(8, os.cpu_count())
dataset_parts = (
"train",
@@ -68,8 +77,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
list(manifests.keys()),
dataset_parts,
)
-
- extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+ if whisper_fbank:
+ extractor = WhisperFbank(
+ WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+ )
+ else:
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
- logging.info(f"Doing speed perturb")
+ logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@@ -111,7 +124,12 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
-
+ parser.add_argument(
+ "--whisper-fbank",
+ type=str2bool,
+ default=False,
+ help="Use WhisperFbank instead of Fbank. Default: False.",
+ )
return parser.parse_args()
@@ -122,5 +140,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_aishell2(
- num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+ num_mel_bins=args.num_mel_bins,
+ perturb_speed=args.perturb_speed,
+ whisper_fbank=args.whisper_fbank,
)
diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh
index a5eb9bd13..c959bd4d1 100755
--- a/egs/aishell2/ASR/prepare.sh
+++ b/egs/aishell2/ASR/prepare.sh
@@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
fi
+whisper_mel_bins=80
+if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
+ log "Stage 30: Compute whisper fbank for aishell2"
+ if [ ! -f data/fbank/.aishell2.whisper.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+ touch data/fbank/.aishell2.whisper.done
+ fi
+fi
+
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute fbank for musan"
if [ ! -f data/fbank/.msuan.done ]; then
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 8f6a88f59..f9cdfb621 100644
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -296,6 +296,8 @@ class AiShell2AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
index 8a5be94d0..c92c7ab83 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
@@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
- --lang-dir data/lang_char
+ --tokens ./data/lang_char/tokens.txt \
--epoch 25 \
--avg 5
@@ -48,6 +48,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
@@ -57,8 +58,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -115,10 +115,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -154,10 +154,10 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = lexicon.token_table[""]
- params.unk_id = lexicon.token_table[""]
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.unk_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell4/ASR/README.md b/egs/aishell4/ASR/README.md
index 67fa17790..b96161762 100644
--- a/egs/aishell4/ASR/README.md
+++ b/egs/aishell4/ASR/README.md
@@ -3,7 +3,7 @@
This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets).
-The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.
+The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.
(From [Open Speech and Language Resources](https://www.openslr.org/111/))
diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
index f19163988..b5f8468ac 100755
--- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py
+++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py
@@ -29,7 +29,14 @@ import os
from pathlib import Path
import torch
-from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ LilcomChunkyWriter,
+ WhisperFbank,
+ WhisperFbankConfig,
+)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
-def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell4(
+ num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
src_dir = Path("data/manifests/aishell4")
output_dir = Path("data/fbank")
- num_jobs = min(15, os.cpu_count())
+ num_jobs = min(8, os.cpu_count())
dataset_parts = (
"train_S",
@@ -70,7 +79,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
dataset_parts,
)
- extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+ if whisper_fbank:
+ extractor = WhisperFbank(
+ WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+ )
+ else:
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
- logging.info(f"Doing speed perturb")
+ logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
- storage_type=ChunkedLilcomHdf5Writer,
+ storage_type=LilcomChunkyWriter,
)
logging.info("About splitting cuts into smaller chunks")
@@ -121,7 +135,12 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
-
+ parser.add_argument(
+ "--whisper-fbank",
+ type=str2bool,
+ default=False,
+ help="Use WhisperFbank instead of Fbank. Default: False.",
+ )
return parser.parse_args()
@@ -132,5 +151,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_aishell4(
- num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+ num_mel_bins=args.num_mel_bins,
+ perturb_speed=args.perturb_speed,
+ whisper_fbank=args.whisper_fbank,
)
diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh
index e8d9eb7b9..38a36d97a 100755
--- a/egs/aishell4/ASR/prepare.sh
+++ b/egs/aishell4/ASR/prepare.sh
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=-1
-stop_stage=100
+stop_stage=7
perturb_speed=true
@@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
- log "Stage 2: Process aishell4"
+ log "Stage 2: Compute fbank for aishell4"
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
- mkdir -p data/fbank/aishell4
+ mkdir -p data/fbank
./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
- touch data/fbank/aishell4/.fbank.done
+ touch data/fbank/.fbank.done
+ fi
+fi
+
+whisper_mel_bins=80
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+ log "Stage 20: Compute whisper fbank for aishell4"
+ if [ ! -f data/fbank/.aishell4.whisper.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+ touch data/fbank/.aishell4.whisper.done
fi
fi
@@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
- log "Stage 5: Compute fbank for aishell4"
- if [ ! -f data/fbank/.aishell4.done ]; then
- mkdir -p data/fbank
- ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
- touch data/fbank/.aishell4.done
- fi
-fi
-
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
- log "Stage 6: Prepare char based lang"
+ log "Stage 5: Prepare char based lang"
lang_char_dir=data/lang_char
mkdir -p $lang_char_dir
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
index e6db2651f..c10456da5 100644
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -306,7 +306,8 @@ class Aishell4AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
- buffer_size=100000,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
index bf9856c60..246820833 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
@@ -48,6 +48,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
@@ -57,8 +58,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -115,13 +115,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="""The lang dir
- It contains language related input files such as
- "lexicon.txt"
- """,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -157,9 +154,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = lexicon.token_table[""]
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
index f8c10648a..09c873a34 100755
--- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
+++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py
@@ -29,7 +29,14 @@ import os
from pathlib import Path
import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ LilcomChunkyWriter,
+ WhisperFbank,
+ WhisperFbankConfig,
+)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
-def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_alimeeting(
+ num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
src_dir = Path("data/manifests/alimeeting")
output_dir = Path("data/fbank")
- num_jobs = min(15, os.cpu_count())
+ num_jobs = min(8, os.cpu_count())
dataset_parts = (
"train",
@@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
"test",
)
- prefix = "alimeeting"
+ prefix = "alimeeting-far"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
@@ -70,7 +79,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
dataset_parts,
)
- extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+ if whisper_fbank:
+ extractor = WhisperFbank(
+ WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+ )
+ else:
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
- logging.info(f"Doing speed perturb")
+ logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@@ -121,7 +135,12 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
-
+ parser.add_argument(
+ "--whisper-fbank",
+ type=str2bool,
+ default=False,
+ help="Use the Whisper Fbank feature extractor. Default: False.",
+ )
return parser.parse_args()
@@ -132,5 +151,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_alimeeting(
- num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+ num_mel_bins=args.num_mel_bins,
+ perturb_speed=args.perturb_speed,
+ whisper_fbank=args.whisper_fbank,
)
diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh
index c8fed658d..996a1da2d 100755
--- a/egs/alimeeting/ASR/prepare.sh
+++ b/egs/alimeeting/ASR/prepare.sh
@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=-1
-stop_stage=100
+stop_stage=7
perturb_speed=true
# We assume dl_dir (download dir) contains the following
@@ -15,7 +15,7 @@ perturb_speed=true
#
# - $dl_dir/alimeeting
# This directory contains the following files downloaded from
-# https://openslr.org/62/
+# https://openslr.org/119/
#
# - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz
@@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
- log "Stage 2: Process alimeeting"
- if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
- mkdir -p data/fbank/alimeeting
+ log "Stage 2: compute fbank for alimeeting"
+ if [ ! -f data/fbank/.fbank.done ]; then
+ mkdir -p data/fbank
./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
+ touch data/fbank/.fbank.done
+ fi
+fi
+
+whisper_mel_bins=80
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+ log "Stage 20: compute whisper fbank for alimeeting"
+ if [ ! -f data/fbank/.fbank.done ]; then
+ mkdir -p data/fbank
+ ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+ touch data/fbank/.fbank.done
fi
fi
@@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
- log "Stage 5: Compute fbank for alimeeting"
- if [ ! -f data/fbank/.alimeeting.done ]; then
- mkdir -p data/fbank
- ./local/compute_fbank_alimeeting.py --perturb-speed True
- touch data/fbank/.alimeeting.done
- fi
-fi
-
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
- log "Stage 6: Prepare char based lang"
+ log "Stage 5: Prepare char based lang"
lang_char_dir=data/lang_char
mkdir -p $lang_char_dir
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 5ad80817a..410741215 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -288,7 +288,8 @@ class AlimeetingAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
- buffer_size=30000,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
index 8e5cc6075..5dc73c52b 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
@@ -20,7 +20,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
- --lang-dir data/lang_char \
+ --tokens ./data/lang_char/tokens.txt \
--epoch 29 \
--avg 18
@@ -45,12 +45,12 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -85,10 +85,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -122,10 +122,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh
index 1098840f8..15c20692d 100755
--- a/egs/alimeeting/ASR_v2/prepare.sh
+++ b/egs/alimeeting/ASR_v2/prepare.sh
@@ -12,7 +12,7 @@ use_gss=true # Use GSS-based enhancement with MDM setting
#
# - $dl_dir/alimeeting
# This directory contains the following files downloaded from
-# https://openslr.org/62/
+# https://openslr.org/119/
#
# - Train_Ali_far.tar.gz
# - Train_Ali_near.tar.gz
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
index 9d288218a..6b56c8a6a 100644
--- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
@@ -263,6 +263,8 @@ class AlimeetingAsrDataModule:
max_cuts=self.args.max_cuts,
shuffle=False,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
logging.info("About to create train dataloader")
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
index 23a88dd29..8bafaef44 100755
--- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
@@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_char/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_char/tokens.txt \
--epoch 20 \
--avg 10
@@ -86,9 +86,8 @@ import argparse
import logging
from pathlib import Path
-import sentencepiece as spm
+import k2
import torch
-import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@@ -98,8 +97,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -156,10 +154,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -199,10 +197,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
index 8f09f1aa5..30879d8d2 100755
--- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
@@ -70,6 +70,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -851,9 +852,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
index 79474f1d8..554facfc1 100644
--- a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
+++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -269,6 +269,8 @@ class AmiAsrDataModule:
max_cuts=self.args.max_cuts,
shuffle=False,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
logging.info("About to create train dataloader")
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py
index 9b67141c0..d62cdadb7 100755
--- a/egs/ami/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py
@@ -69,6 +69,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -842,9 +843,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
index 1549c1631..ea8b62242 100644
--- a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
+++ b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
@@ -254,6 +254,8 @@ class AmiAsrDataModule:
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py
index cd5fafc34..adc6a8495 100755
--- a/egs/ami/SURT/dprnn_zipformer/train.py
+++ b/egs/ami/SURT/dprnn_zipformer/train.py
@@ -75,6 +75,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1138,9 +1139,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py
index 9f3b4425f..ac5b0dadc 100755
--- a/egs/ami/SURT/dprnn_zipformer/train_adapt.py
+++ b/egs/ami/SURT/dprnn_zipformer/train_adapt.py
@@ -75,6 +75,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1129,9 +1130,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/audioset/AT/README.md b/egs/audioset/AT/README.md
new file mode 100644
index 000000000..2506d41e5
--- /dev/null
+++ b/egs/audioset/AT/README.md
@@ -0,0 +1,12 @@
+# Introduction
+
+This is an audio tagging recipe for [AudioSet](https://research.google.com/audioset/#/). It aims to predict the sound events in an audio clip.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+
+# Zipformer
+
+| Encoder | Feature type |
+| --------| -------------|
+| Zipformer | Frame level fbank|
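+
+A typical workflow looks roughly like the following (a sketch only; the exact
+training and evaluation commands and hyper-parameters are listed in
+[./RESULTS.md](./RESULTS.md)):
+
+```bash
+# Prepare manifests and fbank features (assumes AudioSet has already been
+# downloaded, see the comments in prepare.sh).
+./prepare.sh
+
+# Train an audio tagging model on the balanced subset.
+python zipformer/train.py \
+  --world-size 4 \
+  --num-epochs 50 \
+  --exp-dir zipformer/exp \
+  --audioset-subset balanced \
+  --max-duration 1000
+
+# Evaluate on the AudioSet eval set.
+python zipformer/evaluate.py \
+  --epoch 50 \
+  --avg 10 \
+  --exp-dir zipformer/exp \
+  --max-duration 500
+```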
diff --git a/egs/audioset/AT/RESULTS.md b/egs/audioset/AT/RESULTS.md
new file mode 100644
index 000000000..0128b7018
--- /dev/null
+++ b/egs/audioset/AT/RESULTS.md
@@ -0,0 +1,95 @@
+## Results
+
+### zipformer
+See [zipformer](./zipformer) for more details.
+
+#### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M
+
+You can find a pretrained model, training logs, decoding logs, and decoding results at:
+
+
+The model achieves the following mean average precision (mAP) on AudioSet:
+
+| Model | mAP |
+| ------ | ------- |
+| Zipformer-AT | 45.1 |
+
+The training command is:
+
+```bash
+export CUDA_VISIBLE_DEVICES="4,5,6,7"
+subset=full
+
+python zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 50 \
+ --exp-dir zipformer/exp_at_as_${subset} \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --num-events 527 \
+ --audioset-subset $subset \
+ --max-duration 1000 \
+ --enable-musan True \
+ --master-port 13455
+```
+
+The evaluation command is:
+
+```bash
+python zipformer/evaluate.py \
+ --epoch 32 \
+ --avg 8 \
+ --exp-dir zipformer/exp_at_as_full \
+ --max-duration 500
+```
+
+
+#### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M
+
+You can find a pretrained model, training logs, decoding logs, and decoding results at:
+
+
+The model achieves the following mean average precision (mAP) on AudioSet:
+
+| Model | mAP |
+| ------ | ------- |
+| Zipformer-S-AT | 45.1 |
+
+The training command is:
+
+```bash
+export CUDA_VISIBLE_DEVICES="4,5,6,7"
+subset=full
+
+python zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 50 \
+ --exp-dir zipformer/exp_small_at_as_${subset} \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --num-events 527 \
+ --num-encoder-layers 2,2,2,2,2,2 \
+ --feedforward-dim 512,768,768,768,768,768 \
+ --encoder-dim 192,256,256,256,256,256 \
+ --encoder-unmasked-dim 192,192,192,192,192,192 \
+ --audioset-subset $subset \
+ --max-duration 1200 \
+ --enable-musan True \
+ --master-port 13455
+```
+
+The evaluation command is:
+
+```bash
+python zipformer/evaluate.py \
+ --epoch 31 \
+ --avg 4 \
+ --num-encoder-layers 2,2,2,2,2,2 \
+ --feedforward-dim 512,768,768,768,768,768 \
+ --encoder-dim 192,256,256,256,256,256 \
+ --encoder-unmasked-dim 192,192,192,192,192,192 \
+ --exp-dir zipformer/exp_small_at_as_full \
+ --max-duration 500
+```
\ No newline at end of file
diff --git a/egs/audioset/AT/local/compute_fbank_musan.py b/egs/audioset/AT/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/audioset/AT/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/audioset/AT/local/generate_audioset_manifest.py b/egs/audioset/AT/local/generate_audioset_manifest.py
new file mode 100644
index 000000000..1c5b3457c
--- /dev/null
+++ b/egs/audioset/AT/local/generate_audioset_manifest.py
@@ -0,0 +1,177 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file generates the manifest and computes the fbank features for AudioSet
+dataset. The generated manifests and features are stored in data/fbank.
+"""
+
+import argparse
+import csv
+import glob
+import logging
+import os
+from typing import Dict
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.audio import Recording
+from lhotse.cut import MonoCut
+from lhotse.supervision import SupervisionSegment
+
+from icefall.utils import get_executor
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_ID_mapping(csv_file):
+ # get a mapping from the class id (e.g. "/m/09x0r") to its integer index
+ mapping = {}
+ with open(csv_file, "r") as fin:
+ reader = csv.reader(fin, delimiter=",")
+ for i, row in enumerate(reader):
+ if i == 0:
+ continue
+ mapping[row[1]] = row[0]
+ return mapping
+
+
+def parse_csv(csv_file: str, id_mapping: Dict):
+ # The content of the csv file should be something like this
+ # ------------------------------------------------------
+ # filename label
+ # dataset/AudioSet/balanced/xxxx.wav 0;451
+ # dataset/AudioSet/balanced/xxxy.wav 375
+ # ------------------------------------------------------
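+ # As an illustrative example (hypothetical ids): if class_labels_indices.csv
+ # maps "/m/09x0r" -> "0" and "/m/05zppz" -> "1", the helper below turns the
+ # label string "/m/09x0r,/m/05zppz" into "0;1".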
+
+ def name2id(names):
+ ids = [id_mapping[name] for name in names.split(",")]
+ return ";".join(ids)
+
+ mapping = {}
+ with open(csv_file, "r") as fin:
+ reader = csv.reader(fin, delimiter=" ")
+ for i, row in enumerate(reader):
+ if i <= 2:
+ continue
+ key = row[0].replace(",", "")
+ mapping[key] = name2id(row[-1])
+ return mapping
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument("--dataset-dir", type=str, default="downloads/audioset")
+
+ parser.add_argument(
+ "--split",
+ type=str,
+ default="balanced",
+ choices=["balanced", "unbalanced", "eval"],
+ )
+
+ parser.add_argument(
+ "--feat-output-dir",
+ type=str,
+ default="data/fbank",
+ )
+
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ dataset_dir = args.dataset_dir
+ split = args.split
+ feat_output_dir = args.feat_output_dir
+
+ num_jobs = min(15, os.cpu_count())
+ num_mel_bins = 80
+
+ if split in ["balanced", "unbalanced"]:
+ csv_file = f"{dataset_dir}/{split}_train_segments.csv"
+ elif split == "eval":
+ csv_file = f"{dataset_dir}/eval_segments.csv"
+ else:
+ raise ValueError(f"Unsupported split: {split}")
+
+ class_indices_csv = f"{dataset_dir}/class_labels_indices.csv"
+ id_mapping = get_ID_mapping(class_indices_csv)
+ labels = parse_csv(csv_file, id_mapping)
+
+ audio_files = glob.glob(f"{dataset_dir}/{split}/*.wav")
+
+ new_cuts = []
+ for i, audio in enumerate(audio_files):
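+ # Use the leading part of the file name as the cut id. This assumes the
+ # downloaded wav files are named like "<clip_id>_<suffix>.wav" so that the
+ # id matches the keys parsed from the csv file above.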
+ cut_id = audio.split("/")[-1].split("_")[0]
+ recording = Recording.from_file(audio, cut_id)
+ cut = MonoCut(
+ id=cut_id,
+ start=0.0,
+ duration=recording.duration,
+ channel=0,
+ recording=recording,
+ )
+ supervision = SupervisionSegment(
+ id=cut_id,
+ recording_id=cut.recording.id,
+ start=0.0,
+ channel=0,
+ duration=cut.duration,
+ )
+ try:
+ supervision.audio_event = labels[cut_id]
+ except KeyError:
+ logging.info(f"No labels found for {cut_id}.")
+ continue
+ cut.supervisions = [supervision]
+ new_cuts.append(cut)
+
+ if i % 100 == 0 and i:
+ logging.info(f"Processed {i} cuts until now.")
+
+ cuts = CutSet.from_cuts(new_cuts)
+
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ logging.info(f"Computing fbank features for {split}")
+ with get_executor() as ex:
+ cuts = cuts.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{feat_output_dir}/{split}_feats",
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+
+ manifest_output_dir = feat_output_dir + "/" + f"cuts_audioset_{split}.jsonl.gz"
+
+ logging.info(f"Storing the manifest to {manifest_output_dir}")
+ cuts.to_jsonl(manifest_output_dir)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/audioset/AT/prepare.sh b/egs/audioset/AT/prepare.sh
new file mode 100755
index 000000000..f7f73a008
--- /dev/null
+++ b/egs/audioset/AT/prepare.sh
@@ -0,0 +1,104 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+# run stage -1 to stage 4 by default
+stage=-1
+stop_stage=4
+
+dl_dir=$PWD/download
+
+# We assume that you have downloaded AudioSet and placed it under
+# $dl_dir/audioset; the folder structure should look like this:
+# - $dl_dir/audioset
+# - balanced
+# - eval
+# - unbalanced
+# If you haven't downloaded AudioSet yet, please refer to
+# https://github.com/RicherMans/SAT/blob/main/datasets/audioset/1_download_audioset.sh.
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "Running prepare.sh"
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+ log "Stage 0: Download the necessary csv files"
+ if [ ! -e $dl_dir/audioset/.csv.done]; then
+ wget --continue "http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/class_labels_indices.csv" -O "${dl_dir}/audioset/class_labels_indices.csv"
+ wget --continue http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv -O "${dl_dir}/audioset/balanced_train_segments.csv"
+ wget --continue http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/eval_segments.csv -O "${dl_dir}/audioset/eval_segments.csv"
+ touch $dl_dir/audioset/.csv.done
+ fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
+ fbank_dir=data/fbank
+ if [ ! -e $fbank_dir/.balanced.done ]; then
+ python local/generate_audioset_manifest.py \
+ --dataset-dir $dl_dir/audioset \
+ --split balanced \
+ --feat-output-dir $fbank_dir
+ touch $fbank_dir/.balanced.done
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Construct the audioset manifest and compute the fbank features for unbalanced set"
+ fbank_dir=data/fbank
+ if [ ! -e $fbank_dir/.unbalanced.done ]; then
+ python local/generate_audioset_manifest.py \
+ --dataset-dir $dl_dir/audioset \
+ --split unbalanced \
+ --feat-output-dir $fbank_dir
+ touch $fbank_dir/.unbalanced.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Construct the audioset manifest and compute the fbank features for eval set"
+ fbank_dir=data/fbank
+ if [ ! -e $fbank_dir/.eval.done ]; then
+ python local/generate_audioset_manifest.py \
+ --dataset-dir $dl_dir/audioset \
+ --split eval \
+ --feat-output-dir $fbank_dir
+ touch $fbank_dir/.eval.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to $dl_dir/musan
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.musan.done ]; then
+ lhotse prepare musan $dl_dir/musan data/manifests
+ touch data/manifests/.musan.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank for musan"
+ mkdir -p data/fbank
+ if [ ! -e data/fbank/.musan.done ]; then
+ ./local/compute_fbank_musan.py
+ touch data/fbank/.musan.done
+ fi
+fi
diff --git a/egs/audioset/AT/shared b/egs/audioset/AT/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/audioset/AT/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/audioset/AT/zipformer/at_datamodule.py
similarity index 86%
rename from egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
rename to egs/audioset/AT/zipformer/at_datamodule.py
index cafa4111d..66497c1ca 100644
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
+++ b/egs/audioset/AT/zipformer/at_datamodule.py
@@ -1,7 +1,6 @@
-# Copyright 2021 Piotr Żelasko
-# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
-# See ../../../../LICENSE for clarification regarding multiple authors
+# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import argparse
import inspect
import logging
@@ -26,12 +24,12 @@ from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ AudioTaggingDataset,
CutConcatenate,
CutMix,
DynamicBucketingSampler,
- K2SpeechRecognitionDataset,
PrecomputedFeatures,
- SingleCutSampler,
+ SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@@ -52,14 +50,12 @@ class _SeedWorkers:
fix_random_seed(self.seed + worker_id)
-class CommonVoiceAsrDataModule:
+class AudioSetATDatamodule:
"""
- DataModule for k2 ASR experiments.
- It assumes there is always one train and valid dataloader,
- but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
- and test-other).
+ DataModule for k2 audio tagging (AT) experiments.
- It contains all the common data pipeline modules used in ASR
+
+ It contains all the common data pipeline modules used in AT
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
@@ -67,7 +63,7 @@ class CommonVoiceAsrDataModule:
- augmentation,
- on-the-fly feature extraction
- This class should be derived for specific corpora used in ASR tasks.
+ This class should be derived for specific corpora used in AT tasks.
"""
def __init__(self, args: argparse.Namespace):
@@ -76,7 +72,7 @@ class CommonVoiceAsrDataModule:
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
- title="ASR data related options",
+ title="AT data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
@@ -84,22 +80,17 @@ class CommonVoiceAsrDataModule:
)
group.add_argument(
- "--language",
+ "--audioset-subset",
type=str,
- default="fr",
- help="""Language of Common Voice""",
- )
- group.add_argument(
- "--cv-manifest-dir",
- type=Path,
- default=Path("data/fr/fbank"),
- help="Path to directory with CommonVoice train/dev/test cuts.",
+ default="balanced",
+ choices=["balanced", "full"],
)
+
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
- help="Path to directory with train/valid/test cuts.",
+ help="Path to directory with audioset train/test cuts.",
)
group.add_argument(
"--max-duration",
@@ -218,7 +209,7 @@ class CommonVoiceAsrDataModule:
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
- ) -> DataLoader:
+ ):
"""
Args:
cuts_train:
@@ -232,7 +223,7 @@ class CommonVoiceAsrDataModule:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
- CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@@ -278,7 +269,7 @@ class CommonVoiceAsrDataModule:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
- train = K2SpeechRecognitionDataset(
+ train = AudioTaggingDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
@@ -296,7 +287,7 @@ class CommonVoiceAsrDataModule:
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
- train = K2SpeechRecognitionDataset(
+ train = AudioTaggingDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
@@ -313,11 +304,12 @@ class CommonVoiceAsrDataModule:
drop_last=self.args.drop_last,
)
else:
- logging.info("Using SingleCutSampler.")
- train_sampler = SingleCutSampler(
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
+ drop_last=self.args.drop_last,
)
logging.info("About to create train dataloader")
@@ -352,13 +344,13 @@ class CommonVoiceAsrDataModule:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
- validate = K2SpeechRecognitionDataset(
+ validate = AudioTaggingDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
- validate = K2SpeechRecognitionDataset(
+ validate = AudioTaggingDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
@@ -380,7 +372,7 @@ class CommonVoiceAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
- test = K2SpeechRecognitionDataset(
+ test = AudioTaggingDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
@@ -401,22 +393,28 @@ class CommonVoiceAsrDataModule:
return test_dl
@lru_cache()
- def train_cuts(self) -> CutSet:
- logging.info("About to get train cuts")
- return load_manifest_lazy(
- self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz"
+ def audioset_train_cuts(self) -> CutSet:
+ logging.info("About to get the audioset training cuts.")
+ balanced_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
)
+ if self.args.audioset_subset == "full":
+ unbalanced_cuts = load_manifest_lazy(
+ self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
+ )
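+ # Lazily mux the two subsets; the weights are roughly proportional to
+ # the number of clips in the balanced (~20k) and unbalanced (~2M) subsets.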
+ cuts = CutSet.mux(
+ balanced_cuts,
+ unbalanced_cuts,
+ weights=[20000, 2000000],
+ stop_early=True,
+ )
+ else:
+ cuts = balanced_cuts
+ return cuts
@lru_cache()
- def dev_cuts(self) -> CutSet:
- logging.info("About to get dev cuts")
+ def audioset_eval_cuts(self) -> CutSet:
+ logging.info("About to get audioset eval cuts")
return load_manifest_lazy(
- self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz"
- )
-
- @lru_cache()
- def test_cuts(self) -> CutSet:
- logging.info("About to get test cuts")
- return load_manifest_lazy(
- self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz"
+ self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
)
diff --git a/egs/audioset/AT/zipformer/encoder_interface.py b/egs/audioset/AT/zipformer/encoder_interface.py
new file mode 120000
index 000000000..653c5b09a
--- /dev/null
+++ b/egs/audioset/AT/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/audioset/AT/zipformer/evaluate.py b/egs/audioset/AT/zipformer/evaluate.py
new file mode 100644
index 000000000..b52a284d0
--- /dev/null
+++ b/egs/audioset/AT/zipformer/evaluate.py
@@ -0,0 +1,344 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0"
+
+./zipformer/evaluate.py \
+ --epoch 50 \
+ --avg 10 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+
+"""
+
+import argparse
+import csv
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from at_datamodule import AudioSetATDatamodule
+from lhotse import load_manifest
+
+try:
+ from sklearn.metrics import average_precision_score
+except Exception as ex:
+ raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn")
+from train import add_model_arguments, get_model, get_params, str2multihot
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def inference_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ batch: dict,
+):
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3, feature.shape
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ audio_event = supervisions["audio_event"]
+
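+ # str2multihot (imported from train.py) turns the ";"-separated event ids
+ # into a multi-hot target tensor over the event classes.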
+ label, _ = str2multihot(audio_event)
+ label = label.detach().cpu()
+
+ feature_lens = supervisions["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ audio_logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
+ # convert to probabilities between 0-1
+ audio_logits = audio_logits.sigmoid().detach().cpu()
+
+ return audio_logits, label
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ all_logits = []
+ all_labels = []
+
+ for batch_idx, batch in enumerate(dl):
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+ num_cuts += len(cut_ids)
+
+ audio_logits, labels = inference_one_batch(
+ params=params,
+ model=model,
+ batch=batch,
+ )
+
+ all_logits.append(audio_logits)
+ all_labels.append(labels)
+
+ if batch_idx % 20 == 1:
+ logging.info(f"Processed {num_cuts} cuts already.")
+ logging.info("Finish collecting audio logits")
+
+ return all_logits, all_labels
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AudioSetATDatamodule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "inference_audio_tagging"
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Evaluation started")
+
+ logging.info(params)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info("About to create model")
+
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ ),
+ strict=False,
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ ),
+ strict=False,
+ )
+
+ model.to(device)
+ model.eval()
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ args.return_cuts = True
+ audioset = AudioSetATDatamodule(args)
+
+ audioset_cuts = audioset.audioset_eval_cuts()
+
+ audioset_dl = audioset.valid_dataloaders(audioset_cuts)
+
+ test_sets = ["audioset_eval"]
+
+ logits, labels = decode_dataset(
+ dl=audioset_dl,
+ params=params,
+ model=model,
+ )
+
+ logits = torch.cat(logits, dim=0).squeeze(dim=1).detach().numpy()
+ labels = torch.cat(labels, dim=0).long().detach().numpy()
+
+ # compute the metric
+ mAP = average_precision_score(
+ y_true=labels,
+ y_score=logits,
+ )
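+
+    # `average_precision_score` treats each of the 527 AudioSet event classes
+    # as an independent binary detection problem, computes the average
+    # precision per class and macro-averages them (the usual "mAP" reported
+    # for AudioSet). AP is rank based, so passing raw logits or sigmoid
+    # probabilities gives the same value. A tiny example with made-up numbers:
+    #
+    #   import numpy as np
+    #   from sklearn.metrics import average_precision_score
+    #   y_true = np.array([[1, 0, 1], [0, 1, 0]])               # multi-hot labels
+    #   y_score = np.array([[0.9, 0.2, 0.7], [0.1, 0.8, 0.3]])  # scores
+    #   average_precision_score(y_true=y_true, y_score=y_score)  # -> 1.0 here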
+
+ logging.info(f"mAP for audioset eval is: {mAP}")
+
+ logging.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/audioset/AT/zipformer/export-onnx.py b/egs/audioset/AT/zipformer/export-onnx.py
new file mode 100755
index 000000000..24b7717b4
--- /dev/null
+++ b/egs/audioset/AT/zipformer/export-onnx.py
@@ -0,0 +1,412 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
+
+"""
+This script exports an audio tagging model from PyTorch to ONNX.
+
+Usage of this script:
+
+ repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
+ repo=$(basename $repo_url)
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ pushd $repo/exp
+ git lfs pull --include pretrained.pt
+ ln -s pretrained.pt epoch-99.pt
+ popd
+
+ python3 zipformer/export-onnx.py \
+ --exp-dir $repo/exp \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0
+
+ pushd $repo/exp
+ mv model-epoch-99-avg-1.onnx model.onnx
+ mv model-epoch-99-avg-1.int8.onnx model.int8.onnx
+ popd
+
+See ./onnx_pretrained.py for how to use the exported ONNX models.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict
+
+import k2
+import onnx
+import onnxoptimizer
+import torch
+import torch.nn as nn
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from onnxsim import simplify
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_model, get_params
+from zipformer import Zipformer2
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import make_pad_mask, num_tokens, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+ """Add meta data to an ONNX model. It is changed in-place.
+
+ Args:
+ filename:
+ Filename of the ONNX model to be changed.
+ meta_data:
+ Key-value pairs.
+ """
+ model = onnx.load(filename)
+ for key, value in meta_data.items():
+ meta = model.metadata_props.add()
+ meta.key = key
+ meta.value = value
+
+ onnx.save(model, filename)
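+
+
+# Not used by the export flow below; a minimal sketch showing how the metadata
+# written by `add_meta_data` can be read back for debugging. onnxruntime users
+# can instead call `session.get_modelmeta().custom_metadata_map`, as
+# ./onnx_pretrained.py does.
+def _print_meta_data(filename: str) -> None:
+    model = onnx.load(filename)
+    for prop in model.metadata_props:
+        logging.info(f"{prop.key}={prop.value}")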
+
+
+class OnnxAudioTagger(nn.Module):
+ """A wrapper for Zipformer audio tagger"""
+
+ def __init__(
+ self, encoder: Zipformer2, encoder_embed: nn.Module, classifier: nn.Linear
+ ):
+ """
+ Args:
+ encoder:
+ A Zipformer encoder.
+          encoder_embed:
+            The convolutional subsampling module that embeds the input features.
+          classifier:
+            The linear classifier mapping encoder outputs to event logits.
+ """
+ super().__init__()
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+ self.classifier = classifier
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ ) -> torch.Tensor:
+ """Please see the help information of Zipformer.forward
+
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C)
+ x_lens:
+ A 1-D tensor of shape (N,). Its dtype is torch.int64
+ Returns:
+ Return a tensor containing:
+ - probs, A 2-D tensor of shape (N, num_classes)
+
+ """
+ x, x_lens = self.encoder_embed(x, x_lens)
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2)
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (N,T,C)
+
+ logits = self.classifier(encoder_out) # (N, T, num_classes)
+ # Note that this is slightly different from model.py for better
+ # support of onnx
+ logits = logits.mean(dim=1)
+ probs = logits.sigmoid()
+ return probs
+
+
+def export_audio_tagging_model_onnx(
+ model: OnnxAudioTagger,
+ filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the given encoder model to ONNX format.
+ The exported model has two inputs:
+
+ - x, a tensor of shape (N, T, C); dtype is torch.float32
+ - x_lens, a tensor of shape (N,); dtype is torch.int64
+
+    and it has one output:
+
+      - logits, a tensor of shape (N, num_classes) containing the per-class
+        probabilities (i.e. sigmoid outputs)
+
+ Args:
+      model:
+        The audio tagging model to be exported
+ filename:
+ The filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ x = torch.zeros(1, 200, 80, dtype=torch.float32)
+ x_lens = torch.tensor([200], dtype=torch.int64)
+
+ model = torch.jit.trace(model, (x, x_lens))
+
+ torch.onnx.export(
+ model,
+ (x, x_lens),
+ filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["x", "x_lens"],
+ output_names=["logits"],
+ dynamic_axes={
+ "x": {0: "N", 1: "T"},
+ "x_lens": {0: "N"},
+ "probs": {0: "N"},
+ },
+ )
+
+ meta_data = {
+ "model_type": "zipformer2",
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "zipformer2 audio tagger",
+ "url": "https://github.com/k2-fsa/icefall/tree/master/egs/audioset/AT/zipformer",
+ }
+ logging.info(f"meta_data: {meta_data}")
+
+ add_meta_data(filename=filename, meta_data=meta_data)
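+
+
+# Optional sanity check, kept as an illustrative sketch only (it is not called
+# in main()): run the exported file with onnxruntime and compare it against the
+# torch wrapper on the same dummy input used for tracing.
+def _verify_onnx_export(model: OnnxAudioTagger, filename: str) -> None:
+    import onnxruntime as ort
+
+    x = torch.zeros(1, 200, 80, dtype=torch.float32)
+    x_lens = torch.tensor([200], dtype=torch.int64)
+    with torch.no_grad():
+        torch_probs = model(x, x_lens)
+
+    session = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])
+    (onnx_probs,) = session.run(
+        [session.get_outputs()[0].name],
+        {"x": x.numpy(), "x_lens": x_lens.numpy()},
+    )
+    assert torch.allclose(torch_probs, torch.from_numpy(onnx_probs), atol=1e-4)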
+
+
+def optimize_model(filename):
+ # see
+ # https://github.com/microsoft/onnxruntime/issues/1899#issuecomment-534806537
+ # and
+ # https://github.com/onnx/onnx/issues/582#issuecomment-937788108
+ # and
+ # https://github.com/onnx/optimizer/issues/110
+ # and
+ # https://qiita.com/Yossy_Hal/items/34f3b2aef2199baf7f5f
+ passes = ["eliminate_unused_initializer"]
+ onnx_model = onnx.load(filename)
+ onnx_model = onnxoptimizer.optimize(onnx_model, passes)
+
+ model_simp, check = simplify(onnx_model)
+ if check:
+ logging.info("Simplified the model!")
+ onnx_model = model_simp
+ else:
+ logging.info("Failed to simplify the model!")
+
+ onnx.save(onnx_model, filename)
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ model.to(device)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
+
+ model = OnnxAudioTagger(
+ encoder=model.encoder,
+ encoder_embed=model.encoder_embed,
+ classifier=model.classifier,
+ )
+
+ model_num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"total parameters: {model_num_param}")
+
+ if params.iter > 0:
+ suffix = f"iter-{params.iter}"
+ else:
+ suffix = f"epoch-{params.epoch}"
+
+ suffix += f"-avg-{params.avg}"
+
+ opset_version = 13
+
+ logging.info("Exporting audio tagging model")
+ model_filename = params.exp_dir / f"model-{suffix}.onnx"
+ export_audio_tagging_model_onnx(
+ model,
+ model_filename,
+ opset_version=opset_version,
+ )
+ optimize_model(model_filename)
+ logging.info(f"Exported audio tagging model to {model_filename}")
+
+ # Generate int8 quantization models
+ # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+ logging.info("Generate int8 quantization models")
+
+ model_filename_int8 = params.exp_dir / f"model-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=model_filename,
+ model_output=model_filename_int8,
+ op_types_to_quantize=["MatMul"],
+ weight_type=QuantType.QInt8,
+ )
+ optimize_model(model_filename_int8)
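+
+    # Purely informational: int8 weights need 1/4 of the storage of fp32 for
+    # the quantized MatMul layers, so the size difference gives a rough idea of
+    # how much of the model the dynamic quantization touched.
+    logging.info(
+        f"fp32 onnx size: {model_filename.stat().st_size} bytes, "
+        f"int8 onnx size: {model_filename_int8.stat().st_size} bytes"
+    )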
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/audioset/AT/zipformer/export.py b/egs/audioset/AT/zipformer/export.py
new file mode 100755
index 000000000..6ceeca8de
--- /dev/null
+++ b/egs/audioset/AT/zipformer/export.py
@@ -0,0 +1,340 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+Note: This is an example for the AudioSet dataset. If you are using a
+different dataset, you should change the argument values accordingly.
+
+(1) Export to torchscript model using torch.jit.script()
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("jit_script.pt")`.
+
+Check ./jit_pretrained.py for its usage.
+
+Check https://github.com/k2-fsa/sherpa
+and https://github.com/k2-fsa/sherpa-onnx
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --epoch 30 \
+ --avg 9
+
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `zipformer/evaluate.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/audioset/AT
+ ./zipformer/evaluate.py \
+ --exp-dir ./zipformer/exp \
+ --use-averaged-model False \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600
+
+Check ./pretrained.py for its usage.
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Tuple
+
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from torch import Tensor, nn
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import make_pad_mask, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ It will generate a file named jit_script.pt.
+ Check ./jit_pretrained.py for how to use it.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+class EncoderModel(nn.Module):
+ """A wrapper for encoder and encoder_embed"""
+
+ def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+ super().__init__()
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+
+ def forward(
+ self, features: Tensor, feature_lengths: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ features: (N, T, C)
+ feature_lengths: (N,)
+ """
+ x, x_lens = self.encoder_embed(features, feature_lengths)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ return encoder_out, encoder_out_lens
+
+
+class Classifier(nn.Module):
+ """A wrapper for audio tagging classifier"""
+
+ def __init__(self, classifier: nn.Module) -> None:
+ super().__init__()
+ self.classifier = classifier
+
+ def forward(self, encoder_out: Tensor, encoder_out_lens: Tensor):
+ """
+ Args:
+ encoder_out:
+ A 3-D tensor of shape (N, T, C).
+ encoder_out_lens:
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
+ before padding.
+ """
+ logits = self.classifier(encoder_out) # (N, T, num_classes)
+ padding_mask = make_pad_mask(encoder_out_lens)
+        logits[padding_mask] = 0  # zero out the padding frames
+        logits = logits.sum(dim=1)  # sum over the time axis
+        logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
+            logits
+        )  # i.e. average over the non-padding frames only
+
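+        # E.g. (illustrative numbers) with encoder_out_lens = [2, 3] and T = 3,
+        # the first utterance is averaged over its first 2 frames only and the
+        # second over all 3 frames; this mirrors forward_audio_tagging() in
+        # ./zipformer/model.py.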
+ return logits
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+
+ logging.info(f"device: {device}")
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.eval()
+
+ if params.jit is True:
+ convert_scaled_to_non_scaled(model, inplace=True)
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+
+ model.encoder = EncoderModel(model.encoder, model.encoder_embed)
+ model.classifier = Classifier(model.classifier)
+ filename = "jit_script.pt"
+
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ model.save(str(params.exp_dir / filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torchscript. Export model.state_dict()")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/audioset/AT/zipformer/jit_pretrained.py b/egs/audioset/AT/zipformer/jit_pretrained.py
new file mode 100755
index 000000000..403308fcf
--- /dev/null
+++ b/egs/audioset/AT/zipformer/jit_pretrained.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
+# 2024 Xiaoyu Yang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+Usage of this script:
+
+ repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
+ repo=$(basename $repo_url)
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ pushd $repo/exp
+ git lfs pull --include jit_script.pt
+ popd
+
+ python3 zipformer/jit_pretrained.py \
+ --nn-model-filename $repo/exp/jit_script.pt \
+ --label-dict $repo/data/class_labels_indices.csv \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav \
+ $repo/test_wavs/3.wav \
+ $repo/test_wavs/4.wav
+"""
+
+import argparse
+import csv
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--nn-model-filename",
+ type=str,
+ required=True,
+ help="Path to the torchscript model cpu_jit.pt",
+ )
+
+ parser.add_argument(
+ "--label-dict",
+ type=str,
+ help="""class_labels_indices.csv.""",
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0].contiguous())
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ model = torch.jit.load(args.nn_model_filename)
+
+ model.eval()
+
+ model.to(device)
+
+ # get the label dictionary
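+    # (The file is the standard AudioSet "class_labels_indices.csv"; its header
+    # is "index,mid,display_name" and rows look like `0,/m/09x0r,"Speech"`, so
+    # row[0] is the integer class index and row[2] the human-readable label.)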
+ label_dict = {}
+ with open(args.label_dict, "r") as f:
+ reader = csv.reader(f, delimiter=",")
+ for i, row in enumerate(reader):
+ if i == 0:
+ continue
+ label_dict[int(row[0])] = row[2]
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
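+    # Kaldi convention: a non-positive high_freq is interpreted as an offset
+    # from the Nyquist frequency, i.e. 8000 - 400 = 7600 Hz for 16 kHz audio.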
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {args.sound_files}")
+ waves = read_sound_files(
+ filenames=args.sound_files,
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(
+ features,
+ batch_first=True,
+ padding_value=math.log(1e-10),
+ )
+
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ features=features,
+ feature_lengths=feature_lengths,
+ )
+
+ logits = model.classifier(encoder_out, encoder_out_lens)
+
+ for filename, logit in zip(args.sound_files, logits):
+ topk_prob, topk_index = logit.sigmoid().topk(5)
+ topk_labels = [label_dict[index.item()] for index in topk_index]
+ logging.info(
+ f"{filename}: Top 5 predicted labels are {topk_labels} with "
+ f"probability of {topk_prob.tolist()}"
+ )
+
+ logging.info("Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/audioset/AT/zipformer/model.py b/egs/audioset/AT/zipformer/model.py
new file mode 100644
index 000000000..f189eac62
--- /dev/null
+++ b/egs/audioset/AT/zipformer/model.py
@@ -0,0 +1,157 @@
+# Copyright 2021-2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import random
+from typing import List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from encoder_interface import EncoderInterface
+
+from icefall.utils import AttributeDict, make_pad_mask
+
+
+class AudioTaggingModel(nn.Module):
+ def __init__(
+ self,
+ encoder_embed: nn.Module,
+ encoder: EncoderInterface,
+ encoder_dim: int = 384,
+ num_events: int = 527,
+ ):
+ """An audio tagging model
+
+ Args:
+ encoder_embed:
+ It is a Convolutional 2D subsampling module. It converts
+            an input of shape (N, T, idim) to an output of shape
+ (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
+ encoder:
+            The Zipformer encoder. It accepts two inputs: `x` of shape
+            (N, T, encoder_dim) and `x_lens` of shape (N,). It returns two
+            tensors: `encoder_out` of shape (N, T, encoder_dim) and
+            `encoder_out_lens` of shape (N,).
+ encoder_dim:
+ Dimension of the encoder.
+          num_events:
+ The number of classes.
+ """
+ super().__init__()
+
+ assert isinstance(encoder, EncoderInterface), type(encoder)
+
+ self.encoder_embed = encoder_embed
+ self.encoder = encoder
+ self.encoder_dim = encoder_dim
+
+ self.classifier = nn.Sequential(
+ nn.Dropout(0.1),
+ nn.Linear(encoder_dim, num_events),
+ )
+
+        # for multi-label classification
+ self.criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")
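+        # Targets are multi-hot vectors of shape (N, num_events): every event
+        # is an independent binary decision, so e.g. a clip containing both
+        # "Speech" and "Music" (illustrative) has 1s at those two indices and
+        # 0s elsewhere, and the per-element losses are summed over the batch.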
+
+ def forward_encoder(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute encoder outputs.
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C).
+ x_lens:
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
+ before padding.
+
+ Returns:
+ encoder_out:
+ Encoder output, of shape (N, T, C).
+ encoder_out_lens:
+ Encoder output lengths, of shape (N,).
+ """
+ # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
+ x, x_lens = self.encoder_embed(x, x_lens)
+ # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+ assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
+
+ return encoder_out, encoder_out_lens
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ target: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C).
+ x_lens:
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
+ before padding.
+          target:
+            The ground-truth labels of the audio events; may be multi-hot.
+        Returns:
+          Return the binary cross-entropy loss.
+ """
+ assert x.ndim == 3, x.shape
+ assert x_lens.ndim == 1, x_lens.shape
+
+ # Compute encoder outputs
+ encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
+
+        # Forward the audio tagging classifier
+ logits = self.forward_audio_tagging(
+ encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
+ ) # (N, num_classes)
+
+ loss = self.criterion(logits, target)
+
+ return loss
+
+ def forward_audio_tagging(self, encoder_out, encoder_out_lens):
+ """
+ Args:
+ encoder_out:
+ A 3-D tensor of shape (N, T, C).
+ encoder_out_lens:
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
+ before padding.
+
+ Returns:
+          A 2-D tensor of shape (N, num_classes).
+ """
+ logits = self.classifier(encoder_out) # (N, T, num_classes)
+ padding_mask = make_pad_mask(encoder_out_lens)
+        logits[padding_mask] = 0  # zero out the padding frames
+        logits = logits.sum(dim=1)  # sum over the time axis
+        logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
+            logits
+        )  # i.e. average over the non-padding frames only
+
+ return logits
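+
+
+def _demo_masked_pooling() -> None:
+    """A tiny self-contained sketch (not used anywhere in training) showing how
+    forward_audio_tagging() averages frame-level logits over the non-padding
+    frames only. All numbers are made up."""
+    encoder_out_lens = torch.tensor([2, 3])
+    logits = torch.arange(6, dtype=torch.float32).reshape(2, 3, 1)
+    padding_mask = make_pad_mask(encoder_out_lens)  # True at padded positions
+    logits[padding_mask] = 0
+    pooled = logits.sum(dim=1) / (~padding_mask).sum(dim=1).unsqueeze(-1)
+    # utterance 0: (0 + 1) / 2 = 0.5; utterance 1: (3 + 4 + 5) / 3 = 4.0
+    assert torch.allclose(pooled, torch.tensor([[0.5], [4.0]]))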
diff --git a/egs/audioset/AT/zipformer/onnx_pretrained.py b/egs/audioset/AT/zipformer/onnx_pretrained.py
new file mode 100755
index 000000000..82fa3d45b
--- /dev/null
+++ b/egs/audioset/AT/zipformer/onnx_pretrained.py
@@ -0,0 +1,228 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2022 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads ONNX models and uses them to decode waves.
+
+Usage of this script:
+
+ repo_url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09
+ repo=$(basename $repo_url)
+ git clone $repo_url
+ pushd $repo
+ git lfs pull --include "*.onnx"
+ popd
+
+ for m in model.onnx model.int8.onnx; do
+ python3 zipformer/onnx_pretrained.py \
+      --model-filename $repo/$m \
+ --label-dict $repo/class_labels_indices.csv \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav \
+ $repo/test_wavs/3.wav \
+ $repo/test_wavs/4.wav
+ done
+"""
+
+import argparse
+import csv
+import logging
+import math
+from typing import List, Tuple
+
+import k2
+import kaldifeat
+import onnxruntime as ort
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--model-filename",
+ type=str,
+ required=True,
+ help="Path to the onnx model. ",
+ )
+
+ parser.add_argument(
+ "--label-dict",
+ type=str,
+ help="""class_labels_indices.csv.""",
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(
+ self,
+ nn_model: str,
+ ):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 4
+
+ self.session_opts = session_opts
+
+ self.init_model(nn_model)
+
+ def init_model(self, nn_model: str):
+ self.model = ort.InferenceSession(
+ nn_model,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ meta = self.model.get_modelmeta().custom_metadata_map
+ print(meta)
+
+ def __call__(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C)
+ x_lens:
+            A 1-D tensor of shape (N,). Its dtype is torch.int64
+ Returns:
+ Return a Tensor:
+ - probs, its shape is (N, num_classes)
+ """
+ out = self.model.run(
+ [
+ self.model.get_outputs()[0].name,
+ ],
+ {
+ self.model.get_inputs()[0].name: x.numpy(),
+ self.model.get_inputs()[1].name: x_lens.numpy(),
+ },
+ )
+ return torch.from_numpy(out[0])
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0])
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+ model = OnnxModel(
+ nn_model=args.model_filename,
+ )
+
+ # get the label dictionary
+ label_dict = {}
+ with open(args.label_dict, "r") as f:
+ reader = csv.reader(f, delimiter=",")
+ for i, row in enumerate(reader):
+ if i == 0:
+ continue
+ label_dict[int(row[0])] = row[2]
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = "cpu"
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = args.sample_rate
+ opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {args.sound_files}")
+ waves = read_sound_files(
+ filenames=args.sound_files,
+ expected_sample_rate=args.sample_rate,
+ )
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(
+ features,
+ batch_first=True,
+ padding_value=math.log(1e-10),
+ )
+
+ feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
+ probs = model(features, feature_lengths)
+
+ for filename, prob in zip(args.sound_files, probs):
+ topk_prob, topk_index = prob.topk(5)
+ topk_labels = [label_dict[index.item()] for index in topk_index]
+ logging.info(
+ f"{filename}: Top 5 predicted labels are {topk_labels} with "
+ f"probability of {topk_prob.tolist()}"
+ )
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/audioset/AT/zipformer/optim.py b/egs/audioset/AT/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/audioset/AT/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/audioset/AT/zipformer/pretrained.py b/egs/audioset/AT/zipformer/pretrained.py
new file mode 100755
index 000000000..bdbd799fa
--- /dev/null
+++ b/egs/audioset/AT/zipformer/pretrained.py
@@ -0,0 +1,202 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate a checkpoint with ./zipformer/export.py.
+
+Note: This is an example for the AudioSet dataset. If you are using a
+different dataset, you should change the argument values accordingly.
+
+Usage of this script:
+
+ repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12
+ repo=$(basename $repo_url)
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ pushd $repo/exp
+ git lfs pull --include pretrained.pt
+ popd
+
+ python3 zipformer/pretrained.py \
+ --checkpoint $repo/exp/pretrained.pt \
+ --label-dict $repo/data/class_labels_indices.csv \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav \
+ $repo/test_wavs/3.wav \
+ $repo/test_wavs/4.wav
+"""
+
+
+import argparse
+import csv
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--label-dict",
+ type=str,
+ help="""class_labels_indices.csv.""",
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0].contiguous())
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+
+ params.update(vars(args))
+
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ logging.info("Creating model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+
+ # get the label dictionary
+ label_dict = {}
+ with open(params.label_dict, "r") as f:
+ reader = csv.reader(f, delimiter=",")
+ for i, row in enumerate(reader):
+ if i == 0:
+ continue
+ label_dict[int(row[0])] = row[2]
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ # model forward and predict the audio events
+ encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
+ logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
+
+ for filename, logit in zip(args.sound_files, logits):
+ topk_prob, topk_index = logit.sigmoid().topk(5)
+ topk_labels = [label_dict[index.item()] for index in topk_index]
+ logging.info(
+ f"{filename}: Top 5 predicted labels are {topk_labels} with "
+ f"probability of {topk_prob.tolist()}"
+ )
+
+ logging.info("Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/audioset/AT/zipformer/scaling.py b/egs/audioset/AT/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/audioset/AT/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/audioset/AT/zipformer/scaling_converter.py b/egs/audioset/AT/zipformer/scaling_converter.py
new file mode 120000
index 000000000..b0ecee05e
--- /dev/null
+++ b/egs/audioset/AT/zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/audioset/AT/zipformer/subsampling.py b/egs/audioset/AT/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/audioset/AT/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py
new file mode 100644
index 000000000..0e234c59f
--- /dev/null
+++ b/egs/audioset/AT/zipformer/train.py
@@ -0,0 +1,1186 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --audioset-subset full \
+ --max-duration 1000
+
+
+"""
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from at_datamodule import AudioSetATDatamodule
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AudioTaggingModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
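+# Worked example with illustrative numbers: --max-duration 1000, --world-size 4
+# and the default --ref-duration 600 make every real batch count as
+# 1000 * 4 / 600 ≈ 6.67 "reference" batches for the schedules that
+# set_batch_count() drives (e.g. the ScheduledFloat dropout values).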
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model. Do not recommend to use this for AT",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--num-events", type=int, default=527, help="Number of sound events"
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def _str2modulelist(s: str, add_dot: bool = True):
+ if add_dot:
+ return [ss.strip() + "." for ss in s.split(",")] if s is not None else None
+ else:
+ return [ss.strip() for ss in s.split(",")] if s is not None else None
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+        - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
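+ # For example (illustrative only): an input with T=100 feature frames
+ # yields (100 - 7) // 2 = 46 frames at the output of encoder_embed,
+ # before the additional 2x downsampling mentioned above.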
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ model = AudioTaggingModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ num_events=params.num_events,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ events = supervisions[
+ "audio_event"
+ ] # the label indices are in CED format (https://github.com/RicherMans/CED)
+ labels, _ = str2multihot(events, n_classes=params.num_events)
+ labels = labels.to(device)
+
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ with torch.set_grad_enabled(is_training):
+ loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ target=labels,
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+
+ return loss, info
+
+
+def str2multihot(events: List[str], n_classes=527, id_mapping=None):
+ # Convert strings separated by semi-colon to multi-hot class labels
+ # input: ["0;1", "1;2"]
+ # output: torch.tensor([[1,1,0], [0,1,1]])
+ labels = [list(map(int, event.split(";"))) for event in events]
+ batch_size = len(labels)
+ out = torch.zeros(batch_size, n_classes)
+
+ for i, label in enumerate(labels):
+ if id_mapping is not None:
+ label = [id_mapping[lb] for lb in label]
+ out[i, label] = 1
+
+ return out, labels
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = batch["inputs"].size(0)
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(
+ model,
+ lr=params.base_lr,
+ include_names=True,
+ ),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ audioset = AudioSetATDatamodule(args)
+ train_cuts = audioset.audioset_train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 30 seconds
+ #
+ # Caution: There is a reason to select 30.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 30.0:
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = audioset.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = audioset.audioset_eval_cuts()
+ valid_dl = audioset.valid_dataloaders(valid_cuts)
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(
+ batch,
+ params=params,
+ )
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ AudioSetATDatamodule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/audioset/AT/zipformer/zipformer.py b/egs/audioset/AT/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/audioset/AT/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/RESULTS.md b/egs/commonvoice/ASR/RESULTS.md
index 2c158d91d..f384f66a0 100644
--- a/egs/commonvoice/ASR/RESULTS.md
+++ b/egs/commonvoice/ASR/RESULTS.md
@@ -1,20 +1,91 @@
## Results
-### GigaSpeech BPE training results (Pruned Stateless Transducer 7)
+
+### Commonvoice Cantonese (zh-HK) Char training results (Zipformer)
+
+See #1546 for more details.
+
+Number of model parameters: 72526519, i.e., 72.53 M
+
+The best CER for CommonVoice 16.1 (cv-corpus-16.1-2023-12-06/zh-HK) is below:
+
+| | Dev | Test | Note |
+|----------------------|-------|------|--------------------|
+| greedy_search | 1.17 | 1.22 | --epoch 24 --avg 5 |
+| modified_beam_search | 0.98 | 1.11 | --epoch 24 --avg 5 |
+| fast_beam_search | 1.08 | 1.27 | --epoch 24 --avg 5 |
+
+When performing cross-corpus validation on [MDCC](https://arxiv.org/abs/2201.02419) (without blank penalty),
+the best CER is below:
+
+| | Dev | Test | Note |
+|----------------------|-------|------|--------------------|
+| greedy_search | 42.40 | 42.03| --epoch 24 --avg 5 |
+| modified_beam_search | 39.73 | 39.19| --epoch 24 --avg 5 |
+| fast_beam_search | 42.14 | 41.98| --epoch 24 --avg 5 |
+
+When performing cross-corpus validation on [MDCC](https://arxiv.org/abs/2201.02419) (with the blank penalty set to 2.2),
+the best CER is below:
+
+| | Dev | Test | Note |
+|----------------------|-------|------|----------------------------------------|
+| greedy_search | 39.19 | 39.09| --epoch 24 --avg 5 --blank-penalty 2.2 |
+| modified_beam_search | 37.73 | 37.65| --epoch 24 --avg 5 --blank-penalty 2.2 |
+| fast_beam_search | 37.73 | 37.74| --epoch 24 --avg 5 --blank-penalty 2.2 |
+
+To reproduce the above result, use the following commands for training:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1"
+./zipformer/train_char.py \
+ --world-size 2 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --cv-manifest-dir data/zh-HK/fbank \
+ --language zh-HK \
+ --use-validated-set 1 \
+ --context-size 1 \
+ --max-duration 1000
+```
+
+and the following commands for decoding:
+
+```bash
+for method in greedy_search modified_beam_search fast_beam_search; do
+ ./zipformer/decode_char.py \
+ --epoch 24 \
+ --avg 5 \
+ --decoding-method $method \
+ --exp-dir zipformer/exp \
+ --cv-manifest-dir data/zh-HK/fbank \
+ --context-size 1 \
+ --language zh-HK
+done
+```
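+
+The blank-penalty results above are obtained by additionally passing
+`--blank-penalty 2.2` to the decoding script, e.g. (a sketch showing only the
+extra flag; the MDCC manifests used for the cross-corpus validation are
+prepared separately):
+
+```bash
+./zipformer/decode_char.py \
+ --epoch 24 \
+ --avg 5 \
+ --decoding-method modified_beam_search \
+ --exp-dir zipformer/exp \
+ --cv-manifest-dir data/zh-HK/fbank \
+ --context-size 1 \
+ --language zh-HK \
+ --blank-penalty 2.2
+```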
+
+Detailed experimental results and the pre-trained model are available at:
+
+
+
+### CommonVoice English (en) BPE training results (Pruned Stateless Transducer 7)
#### [pruned_transducer_stateless7](./pruned_transducer_stateless7)
-See #997 for more details.
+See #997 for more details.
Number of model parameters: 70369391, i.e., 70.37 M
+Note that the results are obtained using a BPE model trained on GigaSpeech transcripts.
+
The best WER, as of 2023-04-17, for Common Voice English 13.0 (cv-corpus-13.0-2023-03-09/en) is below:
Results are:
| | Dev | Test |
|----------------------|-------|-------|
-| greedy search | 9.96 | 12.54 |
-| modified beam search | 9.86 | 12.48 |
+| greedy_search | 9.96 | 12.54 |
+| modified_beam_search | 9.86 | 12.48 |
To reproduce the above result, use the following commands for training:
@@ -55,10 +126,6 @@ and the following commands for decoding:
Pretrained model is available at
-The tensorboard log for training is available at
-
-
-
### Commonvoice (fr) BPE training results (Pruned Stateless Transducer 7_streaming)
#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
@@ -73,9 +140,9 @@ Results are:
| decoding method | Test |
|----------------------|-------|
-| greedy search | 9.95 |
-| modified beam search | 9.57 |
-| fast beam search | 9.67 |
+| greedy_search | 9.95 |
+| modified_beam_search | 9.57 |
+| fast_beam_search | 9.67 |
Note: This best result is trained on the full librispeech and gigaspeech, and then fine-tuned on the full commonvoice.
diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py
deleted file mode 120000
index 471aa7fb4..000000000
--- a/egs/commonvoice/ASR/local/compile_hlg.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py
new file mode 100755
index 000000000..6512aa68b
--- /dev/null
+++ b/egs/commonvoice/ASR/local/compile_hlg.py
@@ -0,0 +1,168 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates HLG from
+
+ - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+ - L, the lexicon, built from lang_dir/L_disambig.pt
+
+ Caution: We use a lexicon that contains disambiguation symbols
+
+ - G, the LM, built from data/lm/G_n_gram.fst.txt
+
+The generated HLG is saved in $lang_dir/HLG.pt
+"""
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lm",
+ type=str,
+ default="G_3_gram",
+ help="""Stem name for LM used in HLG compiling.
+ """,
+ )
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
+ """
+ Args:
+ lang_dir:
+ The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+ lm:
+ The stem name of the LM to use, e.g., G_3_gram.
+
+ Return:
+ An FSA representing HLG.
+ """
+ lexicon = Lexicon(lang_dir)
+ max_token_id = max(lexicon.tokens)
+ logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
+ H = k2.ctc_topo(max_token_id)
+ L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+
+ if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
+ logging.info(f"Loading pre-compiled {lm}")
+ d = torch.load(f"{lang_dir}/lm/{lm}.pt")
+ G = k2.Fsa.from_dict(d)
+ else:
+ logging.info(f"Loading {lm}.fst.txt")
+ with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
+ G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+ torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
+
+ first_token_disambig_id = lexicon.token_table["#0"]
+ first_word_disambig_id = lexicon.word_table["#0"]
+
+ L = k2.arc_sort(L)
+ G = k2.arc_sort(G)
+
+ logging.info("Intersecting L and G")
+ LG = k2.compose(L, G)
+ logging.info(f"LG shape: {LG.shape}")
+
+ logging.info("Connecting LG")
+ LG = k2.connect(LG)
+ logging.info(f"LG shape after k2.connect: {LG.shape}")
+
+ logging.info(type(LG.aux_labels))
+ logging.info("Determinizing LG")
+
+ LG = k2.determinize(LG)
+ logging.info(type(LG.aux_labels))
+
+ logging.info("Connecting LG after k2.determinize")
+ LG = k2.connect(LG)
+
+ logging.info("Removing disambiguation symbols on LG")
+
+ # LG.labels[LG.labels >= first_token_disambig_id] = 0
+ # see https://github.com/k2-fsa/k2/pull/1140
+ labels = LG.labels
+ labels[labels >= first_token_disambig_id] = 0
+ LG.labels = labels
+
+ assert isinstance(LG.aux_labels, k2.RaggedTensor)
+ LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
+
+ LG = k2.remove_epsilon(LG)
+ logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
+
+ LG = k2.connect(LG)
+ LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+
+ logging.info("Arc sorting LG")
+ LG = k2.arc_sort(LG)
+
+ logging.info("Composing H and LG")
+ # CAUTION: The name of the inner_labels is fixed
+ # to `tokens`. If you want to change it, please
+ # also change other places in icefall that are using
+ # it.
+ HLG = k2.compose(H, LG, inner_labels="tokens")
+
+ logging.info("Connecting LG")
+ HLG = k2.connect(HLG)
+
+ logging.info("Arc sorting LG")
+ HLG = k2.arc_sort(HLG)
+ logging.info(f"HLG.shape: {HLG.shape}")
+
+ return HLG
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+
+ if (lang_dir / "HLG.pt").is_file():
+ logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+ return
+
+ logging.info(f"Processing {lang_dir}")
+
+ HLG = compile_HLG(lang_dir, args.lm)
+ logging.info(f"Saving HLG.pt to {lang_dir}")
+ torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py
deleted file mode 120000
index 462d6d3fb..000000000
--- a/egs/commonvoice/ASR/local/compile_lg.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py
new file mode 100755
index 000000000..76dacb5b2
--- /dev/null
+++ b/egs/commonvoice/ASR/local/compile_lg.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Kang Wei,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates LG from
+
+ - L, the lexicon, built from lang_dir/L_disambig.pt
+
+ Caution: We use a lexicon that contains disambiguation symbols
+
+ - G, the LM, built from lang_dir/lm/G_3_gram.fst.txt
+
+The generated LG is saved in $lang_dir/LG.pt
+"""
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ """,
+ )
+ parser.add_argument(
+ "--lm",
+ type=str,
+ default="G_3_gram",
+ help="""Stem name for LM used in HLG compiling.
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
+ """
+ Args:
+ lang_dir:
+ The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+ lm:
+ The stem name of the LM to use, e.g., G_3_gram.
+
+ Return:
+ An FSA representing LG.
+ """
+ lexicon = Lexicon(lang_dir)
+ L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+
+ if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
+ logging.info(f"Loading pre-compiled {lm}")
+ d = torch.load(f"{lang_dir}/lm/{lm}.pt")
+ G = k2.Fsa.from_dict(d)
+ else:
+ logging.info(f"Loading {lm}.fst.txt")
+ with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
+ G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+ torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
+
+ first_token_disambig_id = lexicon.token_table["#0"]
+ first_word_disambig_id = lexicon.word_table["#0"]
+
+ L = k2.arc_sort(L)
+ G = k2.arc_sort(G)
+
+ logging.info("Intersecting L and G")
+ LG = k2.compose(L, G)
+ logging.info(f"LG shape: {LG.shape}")
+
+ logging.info("Connecting LG")
+ LG = k2.connect(LG)
+ logging.info(f"LG shape after k2.connect: {LG.shape}")
+
+ logging.info(type(LG.aux_labels))
+ logging.info("Determinizing LG")
+
+ LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
+ logging.info(type(LG.aux_labels))
+
+ logging.info("Connecting LG after k2.determinize")
+ LG = k2.connect(LG)
+
+ logging.info("Removing disambiguation symbols on LG")
+
+ # LG.labels[LG.labels >= first_token_disambig_id] = 0
+ # see https://github.com/k2-fsa/k2/pull/1140
+ labels = LG.labels
+ labels[labels >= first_token_disambig_id] = 0
+ LG.labels = labels
+
+ assert isinstance(LG.aux_labels, k2.RaggedTensor)
+ LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
+
+ LG = k2.remove_epsilon(LG)
+ logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
+
+ LG = k2.connect(LG)
+ LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+
+ logging.info("Arc sorting LG")
+ LG = k2.arc_sort(LG)
+
+ return LG
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+
+ if (lang_dir / "LG.pt").is_file():
+ logging.info(f"{lang_dir}/LG.pt already exists - skipping")
+ return
+
+ logging.info(f"Processing {lang_dir}")
+
+ LG = compile_LG(lang_dir, args.lm)
+ logging.info(f"Saving LG.pt to {lang_dir}")
+ torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
index f31b45aa5..aa672609a 100755
--- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
+++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (Yifan Yang)
+# Copyright 2023-2024 Xiaomi Corp. (Yifan Yang,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -17,7 +18,6 @@
import argparse
import logging
-from datetime import datetime
from pathlib import Path
import torch
@@ -30,6 +30,8 @@ from lhotse import (
set_caching_enabled,
)
+from icefall.utils import str2bool
+
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
@@ -41,6 +43,14 @@ torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--subset",
+ type=str,
+ default="train",
+ choices=["train", "validated", "invalidated"],
+ help="""Dataset parts to compute fbank. """,
+ )
+
parser.add_argument(
"--language",
type=str,
@@ -66,28 +76,35 @@ def get_args():
"--num-splits",
type=int,
required=True,
- help="The number of splits of the train subset",
+ help="The number of splits of the subset",
)
parser.add_argument(
"--start",
type=int,
default=0,
- help="Process pieces starting from this number (inclusive).",
+ help="Process pieces starting from this number (included).",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
- help="Stop processing pieces until this number (exclusive).",
+ help="Stop processing pieces until this number (excluded).",
+ )
+
+ parser.add_argument(
+ "--perturb-speed",
+ type=str2bool,
+ default=False,
+ help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
)
return parser.parse_args()
def compute_fbank_commonvoice_splits(args):
- subset = "train"
+ subset = args.subset
num_splits = args.num_splits
language = args.language
output_dir = f"data/{language}/fbank/cv-{language}_{subset}_split_{num_splits}"
@@ -130,6 +147,10 @@ def compute_fbank_commonvoice_splits(args):
keep_overlapping=False, min_duration=None
)
+ if args.perturb_speed:
+ logging.info(f"Doing speed perturb")
+ cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+
logging.info("Computing features")
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
diff --git a/egs/commonvoice/ASR/local/prepare_char.py b/egs/commonvoice/ASR/local/prepare_char.py
new file mode 120000
index 000000000..42743b544
--- /dev/null
+++ b/egs/commonvoice/ASR/local/prepare_char.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/prepare_char.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/prepare_lang.py b/egs/commonvoice/ASR/local/prepare_lang.py
new file mode 120000
index 000000000..747f2ab39
--- /dev/null
+++ b/egs/commonvoice/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/prepare_lang_fst.py b/egs/commonvoice/ASR/local/prepare_lang_fst.py
new file mode 120000
index 000000000..c5787c534
--- /dev/null
+++ b/egs/commonvoice/ASR/local/prepare_lang_fst.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_fst.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py
index 5f6aa3ec0..cc88ef8d7 100755
--- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py
+++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py
@@ -21,7 +21,7 @@ import re
from pathlib import Path
from typing import Optional
-from lhotse import CutSet, SupervisionSegment
+from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
@@ -48,8 +48,33 @@ def normalize_text(utt: str, language: str) -> str:
utt = re.sub("’", "'", utt)
if language == "en":
return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
- if language == "fr":
+ elif language == "fr":
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
+ elif language == "pl":
+ return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
+ elif language in ["yue", "zh-HK"]:
+ # Mozilla Common Voice uses both "yue" and "zh-HK" for Cantonese.
+ # Non-English/non-Cantonese tokens (mostly punctuation and symbols)
+ # are manually removed here.
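+ # For example, "你好,世界!" becomes "你好世界" after this normalization.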
+
+ # fmt: off
+ tokens_to_remove = [",", "。", "?", "!", "?", "!", "‘", "、", ",", "\.", ":", ";", "「", "」", "“", "”", "~", "—", "ㄧ", "《", "》", "…", "⋯", "·", "﹒", ".", ":", "︰", "﹖", "(", ")", "-", "~", ";", "", "⠀", "﹔", "/", "A", "B", "–", "‧"]
+
+ # fmt: on
+ utt = utt.upper().replace("\\", "")
+ return re.sub(
+ pattern="|".join([f"[{token}]" for token in tokens_to_remove]),
+ repl="",
+ string=utt,
+ )
+ else:
+ raise NotImplementedError(
+ f"""
+ Text normalization is not implemented for language: {language}.
+ Please consider implementing it in local/preprocess_commonvoice.py
+ or raising an issue on GitHub to request it.
+ """
+ )
def preprocess_commonvoice(
@@ -111,6 +136,28 @@ def preprocess_commonvoice(
supervisions=m["supervisions"],
).resample(16000)
+ if partition == "validated":
+ logging.warning(
+ """
+ The 'validated' partition contains the data of the 'train', 'dev'
+ and 'test' partitions. We filter out the 'dev' and 'test' partitions
+ here.
+ """
+ )
+ dev_ids = src_dir / f"cv-{language}_dev_ids"
+ test_ids = src_dir / f"cv-{language}_test_ids"
+ assert (
+ dev_ids.is_file()
+ ), f"{dev_ids} does not exist, please check stage 1 of the prepare.sh"
+ assert (
+ test_ids.is_file()
+ ), f"{test_ids} does not exist, please check stage 1 of the prepare.sh"
+ dev_ids = dev_ids.read_text().strip().split("\n")
+ test_ids = test_ids.read_text().strip().split("\n")
+ cut_set = cut_set.filter(
+ lambda x: x.supervisions[0].id not in dev_ids + test_ids
+ )
+
# Run data augmentation that needs to be done in the
# time domain.
logging.info(f"Saving to {raw_cuts_path}")
diff --git a/egs/commonvoice/ASR/local/word_segment_yue.py b/egs/commonvoice/ASR/local/word_segment_yue.py
new file mode 100755
index 000000000..35d262d10
--- /dev/null
+++ b/egs/commonvoice/ASR/local/word_segment_yue.py
@@ -0,0 +1,147 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes a text file "data/lang_char/text" as input. The file consists of
+lines, each containing a transcript. The script applies text normalization and
+generates the following files in the directory "data/lang_char":
+ - transcript_words.txt
+ - words.txt
+ - words_no_ids.txt
+"""
+
+import argparse
+import logging
+import re
+from pathlib import Path
+from typing import List
+
+import pycantonese
+from preprocess_commonvoice import normalize_text
+from tqdm.auto import tqdm
+
+from icefall.utils import is_cjk, tokenize_by_CJK_char
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="Prepare char lexicon",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--input-file",
+ "-i",
+ default="data/yue/lang_char/text",
+ type=str,
+ help="The input text file",
+ )
+ parser.add_argument(
+ "--output-dir",
+ "-o",
+ default="data/yue/lang_char/",
+ type=str,
+ help="The output directory",
+ )
+ parser.add_argument(
+ "--lang",
+ "-l",
+ default="yue",
+ type=str,
+ help="The language",
+ )
+ return parser
+
+
+def get_word_segments(lines: List[str]) -> List[str]:
+ # The current pycantonese segmenter does not handle code-switched input,
+ # so English and Cantonese segments are handled separately below.
+
+ new_lines = []
+
+ for line in tqdm(lines, desc="Segmenting lines"):
+ try:
+ if is_cs(line): # code switching
+ segments = []
+ curr_str = ""
+ for segment in tokenize_by_CJK_char(line).split(" "):
+ if segment.strip() == "":
+ continue
+ try:
+ if not is_cjk(segment[0]): # en segment
+ if curr_str:
+ segments.extend(pycantonese.segment(curr_str))
+ curr_str = ""
+ segments.append(segment)
+ else: # zh segment
+ curr_str += segment
+ # segments.extend(pycantonese.segment(segment))
+ except Exception as e:
+ logging.error(f"Failed to process segment: {segment}")
+ raise
+ if curr_str: # process the last segment
+ segments.extend(pycantonese.segment(curr_str))
+ new_lines.append(" ".join(segments) + "\n")
+ else: # not code switching
+ new_lines.append(" ".join(pycantonese.segment(line)) + "\n")
+ except Exception as e:
+ logging.error(f"Failed to process line: {line}")
+ raise e
+ return new_lines
+
+
+def get_words(lines: List[str]) -> List[str]:
+ words = set()
+ for line in tqdm(lines, desc="Getting words"):
+ words.update(line.strip().split(" "))
+ return list(words)
+
+
+def is_cs(line: str) -> bool:
+ english_markers = r"[a-zA-Z]+"
+ return bool(re.search(english_markers, line))
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+
+ input_file = Path(args.input_file)
+ output_dir = Path(args.output_dir)
+ lang = args.lang
+
+ assert input_file.is_file(), f"{input_file} does not exist"
+ assert output_dir.is_dir(), f"{output_dir} does not exist"
+
+ lines = input_file.read_text(encoding="utf-8").strip().split("\n")
+ norm_lines = [normalize_text(line, lang) for line in lines]
+
+ text_words_segments = get_word_segments(norm_lines)
+ with open(output_dir / "transcript_words.txt", "w", encoding="utf-8") as f:
+ f.writelines(text_words_segments)
+
+ words = [w for w in get_words(text_words_segments) if w != ""] # drop the empty token
+ with open(output_dir / "words_no_ids.txt", "w", encoding="utf-8") as f:
+ f.writelines([word + "\n" for word in sorted(words)])
+
+ words = (
+ ["", "!SIL", "", ""]
+ + sorted(words)
+ + ["#0", "", "<\s>"]
+ )
+
+ with open(output_dir / "words.txt", "w", encoding="utf-8") as f:
+ f.writelines([f"{word} {i}\n" for i, word in enumerate(words)])
diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh
index edac0e8e6..4e76ef041 100755
--- a/egs/commonvoice/ASR/prepare.sh
+++ b/egs/commonvoice/ASR/prepare.sh
@@ -10,6 +10,12 @@ stop_stage=100
# This is to avoid OOM during feature extraction.
num_splits=1000
+# In case you want to use all validated data
+use_validated=false
+
+# In case you are willing to take the risk and use invalidated data
+use_invalidated=false
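+
+# Both switches can be overridden from the command line (they are parsed by
+# shared/parse_options.sh below), e.g.:
+# ./prepare.sh --lang zh-HK --use-validated true --use-invalidated false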
+
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
@@ -38,6 +44,7 @@ num_splits=1000
dl_dir=$PWD/download
release=cv-corpus-12.0-2022-12-07
lang=fr
+perturb_speed=false
. shared/parse_options.sh || exit 1
@@ -100,8 +107,40 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
mkdir -p data/${lang}/manifests
if [ ! -e data/${lang}/manifests/.cv-${lang}.done ]; then
lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests
+
+ if [ $use_validated = true ] && [ ! -f data/${lang}/manifests/.cv-${lang}.validated.done ]; then
+ log "Also prepare validated data"
+ lhotse prepare commonvoice \
+ --split validated \
+ --language $lang \
+ -j $nj $dl_dir/$release data/${lang}/manifests
+ touch data/${lang}/manifests/.cv-${lang}.validated.done
+ fi
+
+ if [ $use_invalidated = true ] && [ ! -f data/${lang}/manifests/.cv-${lang}.invalidated.done ]; then
+ log "Also prepare invalidated data"
+ lhotse prepare commonvoice \
+ --split invalidated \
+ --language $lang \
+ -j $nj $dl_dir/$release data/${lang}/manifests
+ touch data/${lang}/manifests/.cv-${lang}.invalidated.done
+ fi
+
touch data/${lang}/manifests/.cv-${lang}.done
fi
+
+ # Note: in Linux, you can install jq with the following command:
+ # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+ # 2. chmod +x ./jq
+ # 3. cp jq /usr/bin
+ if [ $use_validated = true ]; then
+ log "Getting cut ids from dev/test sets for later use"
+ gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_test.jsonl.gz \
+ | jq '.id' | sed 's/"//g' > data/${lang}/manifests/cv-${lang}_test_ids
+
+ gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_dev.jsonl.gz \
+ | jq '.id' | sed 's/"//g' > data/${lang}/manifests/cv-${lang}_dev_ids
+ fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -121,6 +160,18 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
./local/preprocess_commonvoice.py --language $lang
touch data/${lang}/fbank/.preprocess_complete
fi
+
+ if [ $use_validated = true ] && [ ! -f data/${lang}/fbank/.validated.preprocess_complete ]; then
+ log "Also preprocess validated data"
+ ./local/preprocess_commonvoice.py --language $lang --dataset validated
+ touch data/${lang}/fbank/.validated.preprocess_complete
+ fi
+
+ if [ $use_invalidated = true ] && [ ! -f data/${lang}/fbank/.invalidated.preprocess_complete ]; then
+ log "Also preprocess invalidated data"
+ ./local/preprocess_commonvoice.py --language $lang --dataset invalidated
+ touch data/${lang}/fbank/.invalidated.preprocess_complete
+ fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
@@ -139,6 +190,20 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir
touch $split_dir/.cv-${lang}_train_split.done
fi
+
+ split_dir=data/${lang}/fbank/cv-${lang}_validated_split_${num_splits}
+ if [ $use_validated = true ] && [ ! -f $split_dir/.cv-${lang}_validated.done ]; then
+ log "Also split validated data"
+ lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_validated_raw.jsonl.gz $split_dir
+ touch $split_dir/.cv-${lang}_validated.done
+ fi
+
+ split_dir=data/${lang}/fbank/cv-${lang}_invalidated_split_${num_splits}
+ if [ $use_invalidated = true ] && [ ! -f $split_dir/.cv-${lang}_invalidated.done ]; then
+ log "Also split invalidated data"
+ lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_invalidated_raw.jsonl.gz $split_dir
+ touch $split_dir/.cv-${lang}_invalidated.done
+ fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -149,9 +214,36 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
--batch-duration 200 \
--start 0 \
--num-splits $num_splits \
- --language $lang
+ --language $lang \
+ --perturb-speed $perturb_speed
touch data/${lang}/fbank/.cv-${lang}_train.done
fi
+
+ if [ $use_validated = true ] && [ ! -f data/${lang}/fbank/.cv-${lang}_validated.done ]; then
+ log "Also compute features for validated data"
+ ./local/compute_fbank_commonvoice_splits.py \
+ --subset validated \
+ --num-workers $nj \
+ --batch-duration 200 \
+ --start 0 \
+ --num-splits $num_splits \
+ --language $lang \
+ --perturb-speed $perturb_speed
+ touch data/${lang}/fbank/.cv-${lang}_validated.done
+ fi
+
+ if [ $use_invalidated = true ] && [ ! -f data/${lang}/fbank/.cv-${lang}_invalidated.done ]; then
+ log "Also compute features for invalidated data"
+ ./local/compute_fbank_commonvoice_splits.py \
+ --subset invalidated \
+ --num-workers $nj \
+ --batch-duration 200 \
+ --start 0 \
+ --num-splits $num_splits \
+ --language $lang \
+ --perturb-speed $perturb_speed
+ touch data/${lang}/fbank/.cv-${lang}_invalidated.done
+ fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
@@ -160,6 +252,20 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
pieces=$(find data/${lang}/fbank/cv-${lang}_train_split_${num_splits} -name "cv-${lang}_cuts_train.*.jsonl.gz")
lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz
fi
+
+ if [ $use_validated = true ] && [ -f data/${lang}/fbank/.cv-${lang}_validated.done ]; then
+ log "Also combine features for validated data"
+ pieces=$(find data/${lang}/fbank/cv-${lang}_validated_split_${num_splits} -name "cv-${lang}_cuts_validated.*.jsonl.gz")
+ lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_validated.jsonl.gz
+ touch data/${lang}/fbank/.cv-${lang}_validated.done
+ fi
+
+ if [ $use_invalidated = true ] && [ -f data/${lang}/fbank/.cv-${lang}_invalidated.done ]; then
+ log "Also combine features for invalidated data"
+ pieces=$(find data/${lang}/fbank/cv-${lang}_invalidated_split_${num_splits} -name "cv-${lang}_cuts_invalidated.*.jsonl.gz")
+ lhotse combine $pieces data/${lang}/fbank/cv-${lang}_cuts_invalidated.jsonl.gz
+ touch data/${lang}/fbank/.cv-${lang}_invalidated.done
+ fi
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
@@ -172,83 +278,134 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
- log "Stage 9: Prepare BPE based lang"
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/${lang}/lang_bpe_${vocab_size}
+ if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then
+ log "Stage 9: Prepare Char based lang"
+ lang_dir=data/${lang}/lang_char/
mkdir -p $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
- log "Generate data for BPE training"
- file=$(
- find "data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz"
- )
- gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt
+ log "Generate data for lang preparation"
- # Ensure space only appears once
- sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
- sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
- fi
+ # Prepare text.
+ # Note: in Linux, you can install jq with the following command:
+ # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+ # 2. chmod +x ./jq
+ # 3. cp jq /usr/bin
+ if [ $use_validated = true ]; then
+ gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_validated.jsonl.gz \
+ | jq '.text' | sed 's/"//g' >> $lang_dir/text
+ else
+ gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_train.jsonl.gz \
+ | jq '.text' | sed 's/"//g' > $lang_dir/text
+ fi
+
+ if [ $use_invalidated = true ]; then
+ gunzip -c data/${lang}/manifests/cv-${lang}_supervisions_invalidated.jsonl.gz \
+ | jq '.text' | sed 's/"//g' >> $lang_dir/text
+ fi
- if [ ! -f $lang_dir/words.txt ]; then
- cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
- | sort -u | sed '/^$/d' > $lang_dir/words.txt
- (echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
- cat - $lang_dir/words.txt | sort | uniq | awk '
- BEGIN {
- print " 0";
- }
- {
- if ($1 == "") {
- print " is in the vocabulary!" | "cat 1>&2"
- exit 1;
+ if [ $lang == "yue" ] || [ $lang == "zh-HK" ]; then
+ # Get words.txt and words_no_ids.txt
+ ./local/word_segment_yue.py \
+ --input-file $lang_dir/text \
+ --output-dir $lang_dir \
+ --lang $lang
+
+ mv $lang_dir/text $lang_dir/_text
+ cp $lang_dir/transcript_words.txt $lang_dir/text
+
+ if [ ! -f $lang_dir/tokens.txt ]; then
+ ./local/prepare_char.py --lang-dir $lang_dir
+ fi
+ else
+ log "word_segment_${lang}.py not implemented yet"
+ exit 1
+ fi
+ fi
+ else
+ log "Stage 9: Prepare BPE based lang"
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/${lang}/lang_bpe_${vocab_size}
+ mkdir -p $lang_dir
+
+ if [ ! -f $lang_dir/transcript_words.txt ]; then
+ log "Generate data for BPE training"
+ file=$(
+ find "data/${lang}/fbank/cv-${lang}_cuts_train.jsonl.gz"
+ )
+ # Prepare text.
+ # Note: in Linux, you can install jq with the following command:
+ # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
+ # 2. chmod +x ./jq
+ # 3. cp jq /usr/bin
+ gunzip -c ${file} \
+ | jq '.text' | sed 's/"//g' > $lang_dir/transcript_words.txt
+
+ # Ensure space only appears once
+ sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
+ sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
+ fi
+
+ if [ ! -f $lang_dir/words.txt ]; then
+ cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
+ | sort -u | sed '/^$/d' > $lang_dir/words.txt
+ (echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
+ cat - $lang_dir/words.txt | sort | uniq | awk '
+ BEGIN {
+ print " 0";
}
- if ($1 == "") {
- print " is in the vocabulary!" | "cat 1>&2"
- exit 1;
+ {
+ if ($1 == "") {
+ print " is in the vocabulary!" | "cat 1>&2"
+ exit 1;
+ }
+ if ($1 == "") {
+ print " is in the vocabulary!" | "cat 1>&2"
+ exit 1;
+ }
+ printf("%s %d\n", $1, NR);
}
- printf("%s %d\n", $1, NR);
- }
- END {
- printf("#0 %d\n", NR+1);
- printf(" %d\n", NR+2);
- printf(" %d\n", NR+3);
- }' > $lang_dir/words || exit 1;
- mv $lang_dir/words $lang_dir/words.txt
- fi
+ END {
+ printf("#0 %d\n", NR+1);
+ printf(" %d\n", NR+2);
+ printf(" %d\n", NR+3);
+ }' > $lang_dir/words || exit 1;
+ mv $lang_dir/words $lang_dir/words.txt
+ fi
- if [ ! -f $lang_dir/bpe.model ]; then
- ./local/train_bpe_model.py \
- --lang-dir $lang_dir \
- --vocab-size $vocab_size \
- --transcript $lang_dir/transcript_words.txt
- fi
+ if [ ! -f $lang_dir/bpe.model ]; then
+ ./local/train_bpe_model.py \
+ --lang-dir $lang_dir \
+ --vocab-size $vocab_size \
+ --transcript $lang_dir/transcript_words.txt
+ fi
- if [ ! -f $lang_dir/L_disambig.pt ]; then
- ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang_bpe.py --lang-dir $lang_dir
- log "Validating $lang_dir/lexicon.txt"
- ./local/validate_bpe_lexicon.py \
- --lexicon $lang_dir/lexicon.txt \
- --bpe-model $lang_dir/bpe.model
- fi
+ log "Validating $lang_dir/lexicon.txt"
+ ./local/validate_bpe_lexicon.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --bpe-model $lang_dir/bpe.model
+ fi
- if [ ! -f $lang_dir/L.fst ]; then
- log "Converting L.pt to L.fst"
- ./shared/convert-k2-to-openfst.py \
- --olabels aux_labels \
- $lang_dir/L.pt \
- $lang_dir/L.fst
- fi
+ if [ ! -f $lang_dir/L.fst ]; then
+ log "Converting L.pt to L.fst"
+ ./shared/convert-k2-to-openfst.py \
+ --olabels aux_labels \
+ $lang_dir/L.pt \
+ $lang_dir/L.fst
+ fi
- if [ ! -f $lang_dir/L_disambig.fst ]; then
- log "Converting L_disambig.pt to L_disambig.fst"
- ./shared/convert-k2-to-openfst.py \
- --olabels aux_labels \
- $lang_dir/L_disambig.pt \
- $lang_dir/L_disambig.fst
- fi
- done
+ if [ ! -f $lang_dir/L_disambig.fst ]; then
+ log "Converting L_disambig.pt to L_disambig.fst"
+ ./shared/convert-k2-to-openfst.py \
+ --olabels aux_labels \
+ $lang_dir/L_disambig.pt \
+ $lang_dir/L_disambig.fst
+ fi
+ done
+ fi
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
@@ -256,49 +413,96 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
# We assume you have install kaldilm, if not, please install
# it using: pip install kaldilm
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/${lang}/lang_bpe_${vocab_size}
+ if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then
+ lang_dir=data/${lang}/lang_char
mkdir -p $lang_dir/lm
- #3-gram used in building HLG, 4-gram used for LM rescoring
- for ngram in 3 4; do
- if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then
- ./shared/make_kn_lm.py \
- -ngram-order ${ngram} \
- -text $lang_dir/transcript_words.txt \
- -lm $lang_dir/lm/${ngram}gram.arpa
- fi
- if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then
- python3 -m kaldilm \
- --read-symbol-table="$lang_dir/words.txt" \
- --disambig-symbol='#0' \
- --max-order=${ngram} \
- $lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt
- fi
+ for ngram in 3 ; do
+ if [ ! -f $lang_dir/lm/${ngram}gram.unpruned.arpa ]; then
+ ./shared/make_kn_lm.py \
+ -ngram-order ${ngram} \
+ -text $lang_dir/transcript_words.txt \
+ -lm $lang_dir/lm/${ngram}gram.unpruned.arpa
+ fi
+
+ if [ ! -f $lang_dir/lm/G_${ngram}_gram_char.fst.txt ]; then
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_dir/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=${ngram} \
+ $lang_dir/lm/${ngram}gram.unpruned.arpa \
+ > $lang_dir/lm/G_${ngram}_gram_char.fst.txt
+ fi
+
+ if [ ! -f $lang_dir/lm/HLG.fst ]; then
+ ./local/prepare_lang_fst.py \
+ --lang-dir $lang_dir \
+ --ngram-G $lang_dir/lm/G_${ngram}_gram_char.fst.txt
+ fi
+ done
+ else
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/${lang}/lang_bpe_${vocab_size}
+ mkdir -p $lang_dir/lm
+ #3-gram used in building HLG, 4-gram used for LM rescoring
+ for ngram in 3 4; do
+ if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then
+ ./shared/make_kn_lm.py \
+ -ngram-order ${ngram} \
+ -text $lang_dir/transcript_words.txt \
+ -lm $lang_dir/lm/${ngram}gram.arpa
+ fi
+
+ if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_dir/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=${ngram} \
+ $lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt
+ fi
+ done
done
- done
+ fi
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Compile HLG"
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/${lang}/lang_bpe_${vocab_size}
- ./local/compile_hlg.py --lang-dir $lang_dir
+ if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then
+ lang_dir=data/${lang}/lang_char
+ for ngram in 3 ; do
+ if [ ! -f $lang_dir/lm/HLG_${ngram}.fst ]; then
+ ./local/compile_hlg.py --lang-dir $lang_dir --lm G_${ngram}_gram_char
+ fi
+ done
+ else
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/${lang}/lang_bpe_${vocab_size}
+ ./local/compile_hlg.py --lang-dir $lang_dir
- # Note If ./local/compile_hlg.py throws OOM,
- # please switch to the following command
- #
- # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
- done
+ # Note: If ./local/compile_hlg.py throws OOM,
+ # please switch to the following command
+ #
+ # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
+ done
+ fi
fi
# Compile LG for RNN-T fast_beam_search decoding
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Compile LG"
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/${lang}/lang_bpe_${vocab_size}
- ./local/compile_lg.py --lang-dir $lang_dir
- done
+ if [ $lang == "yue" ] || [ $lang == "zh-TW" ] || [ $lang == "zh-CN" ] || [ $lang == "zh-HK" ]; then
+ lang_dir=data/${lang}/lang_char
+ for ngram in 3 ; do
+ if [ ! -f $lang_dir/lm/LG_${ngram}.fst ]; then
+ ./local/compile_lg.py --lang-dir $lang_dir --lm G_${ngram}_gram_char
+ fi
+ done
+ else
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/${lang}/lang_bpe_${vocab_size}
+ ./local/compile_lg.py --lang-dir $lang_dir
+ done
+ fi
fi
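
The G_*_gram*.fst.txt files produced by kaldilm in Stage 10 are consumed by compile_hlg.py / compile_lg.py in Stages 11 and 12. A sketch of how such a text-format FST is typically loaded with k2, inferred from other icefall recipes (the path below is illustrative, not part of this patch):

# Sketch (inferred from other icefall recipes; the path is illustrative).
import k2
import torch

with open("data/yue/lang_char/lm/G_3_gram_char.fst.txt") as f:
    # kaldilm emits an OpenFst text-format transducer over word ids from words.txt
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

# Downstream scripts typically cache the loaded graph as a .pt file.
torch.save(G.as_dict(), "data/yue/lang_char/lm/G_3_gram_char.pt")
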
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
index 546e9f9dd..a80cfe85e 100644
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -308,6 +308,8 @@ class CommonVoiceAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
@@ -379,9 +381,11 @@ class CommonVoiceAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
- input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
- if self.args.on_the_fly_feats
- else eval(self.args.input_strategy)(),
+ input_strategy=(
+ OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)()
+ ),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
@@ -405,6 +409,22 @@ class CommonVoiceAsrDataModule:
self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz"
)
+ @lru_cache()
+ def validated_cuts(self) -> CutSet:
+ logging.info("About to get validated cuts (with dev/test removed)")
+ return load_manifest_lazy(
+ self.args.cv_manifest_dir
+ / f"cv-{self.args.language}_cuts_validated.jsonl.gz"
+ )
+
+ @lru_cache()
+ def invalidated_cuts(self) -> CutSet:
+ logging.info("About to get invalidated cuts")
+ return load_manifest_lazy(
+ self.args.cv_manifest_dir
+ / f"cv-{self.args.language}_cuts_invalidated.jsonl.gz"
+ )
+
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py
index 19c518eaf..f04537660 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_check.py
@@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp:
import argparse
import logging
-from icefall import is_module_available
+import torch
from onnx_pretrained import OnnxModel
-import torch
+from icefall import is_module_available
def get_parser():
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
index 4aedeffe4..5e98084ec 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python3
-# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
-# Mingshuang Luo,)
-# Zengwei Yao)
+# Mingshuang Luo,
+# Zengwei Yao,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -79,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -248,7 +250,29 @@ def get_parser():
)
parser.add_argument(
- "--base-lr", type=float, default=0.05, help="The base learning rate."
+ "--use-validated-set",
+ type=str2bool,
+ default=False,
+ help="""Use the validated set for training.
+ This is useful when you want to use more data for training,
+ but not recommended for research purposes.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-invalidated-set",
+ type=str2bool,
+ default=False,
+ help="""Use the invalidated set for training.
+ In case you want to take the risk and utilize more data for training.
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr",
+ type=float,
+ default=0.05,
+ help="The base learning rate.",
)
parser.add_argument(
@@ -871,9 +895,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@@ -1028,7 +1050,13 @@ def run(rank, world_size, args):
commonvoice = CommonVoiceAsrDataModule(args)
- train_cuts = commonvoice.train_cuts()
+ if args.use_validated_set:
+ train_cuts = commonvoice.validated_cuts()
+ else:
+ train_cuts = commonvoice.train_cuts()
+
+ if args.use_invalidated_set:
+ train_cuts += commonvoice.invalidated_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
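
A short sketch of the resulting training-cut selection plus the duration filter whose body is truncated in this hunk (flag and method names as in the diff; the filter bounds are taken from the comment above):

# Sketch of the cut selection added above; the filter body is reconstructed
# from the comment in the hunk ("between 1 second and 20 seconds").
from lhotse import CutSet
from lhotse.cut import Cut


def select_train_cuts(commonvoice, use_validated_set: bool, use_invalidated_set: bool) -> CutSet:
    # commonvoice is a CommonVoiceAsrDataModule instance
    cuts = commonvoice.validated_cuts() if use_validated_set else commonvoice.train_cuts()
    if use_invalidated_set:
        cuts += commonvoice.invalidated_cuts()

    def remove_short_and_long_utt(c: Cut) -> bool:
        # Keep only utterances with duration between 1 second and 20 seconds
        return 1.0 <= c.duration <= 20.0

    return cuts.filter(remove_short_and_long_utt)
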
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py
new file mode 120000
index 000000000..c274de28a
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py
index 30f7c1e77..7ae4f1894 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
-# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
-# Zengwei Yao)
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -112,6 +113,7 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
+from asr_datamodule import CommonVoiceAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@@ -122,7 +124,6 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
-from commonvoice_fr import CommonVoiceAsrDataModule
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 0426bc9a3..aefe88f3f 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python3
-# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,)
-# Zengwei Yao)
+# Zengwei Yao,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -55,7 +56,7 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from commonvoice_fr import CommonVoiceAsrDataModule
+from asr_datamodule import CommonVoiceAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -889,9 +890,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise RuntimeError(f", exiting: {cur_grad_scale}")
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@@ -1037,7 +1036,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2**22
+ 512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
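
For context, a sketch of how the diagnostics options above are typically used in icefall training scripts; the print_diagnostics() call is inferred from other recipes and is not shown in this hunk:

# Sketch; print_diagnostics() is inferred from other icefall recipes.
import torch
from icefall import diagnostics

model = torch.nn.Linear(8, 8)  # stand-in for the real transducer model
opts = diagnostics.TensorDiagnosticOptions(512)
diagnostic = diagnostics.attach_diagnostics(model, opts)
model(torch.randn(4, 8)).sum().backward()  # a few forward/backward passes populate the stats
diagnostic.print_diagnostics()
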
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
index 3a10c5d81..976004eca 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python3
-# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
-# Mingshuang Luo,)
-# Zengwei Yao)
+# Mingshuang Luo,
+# Zengwei Yao,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -58,7 +59,7 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from commonvoice_fr import CommonVoiceAsrDataModule
+from asr_datamodule import CommonVoiceAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -81,6 +82,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -965,9 +967,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@@ -1120,7 +1120,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2**22
+ 512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
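
The raise_grad_scale_is_too_small_error helper imported from icefall.err replaces the inline RuntimeError in several files above. Its body is not part of this diff; a minimal sketch of what it presumably does, mirroring only the code it replaces:

# Assumption: the real helper in icefall/err.py is not shown in this diff;
# this sketch mirrors the inline RuntimeError it replaces.
def raise_grad_scale_is_too_small_error(cur_grad_scale: float) -> None:
    raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
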
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index 018736d26..bb1c093c8 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
+# Copyright 2022-2024 Xiaomi Corporation (Authors: Wei Kang,
+# Fangjun Kuang,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -37,7 +39,7 @@ import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
-from commonvoice_fr import CommonVoiceAsrDataModule
+from asr_datamodule import CommonVoiceAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
index a9bc9c2a2..67e1a8133 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python3
-# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
-# Mingshuang Luo,)
-# Zengwei Yao)
+# Mingshuang Luo,
+# Zengwei Yao,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -55,7 +56,7 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from commonvoice_fr import CommonVoiceAsrDataModule
+from asr_datamodule import CommonVoiceAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -78,6 +79,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -264,7 +266,29 @@ def get_parser():
)
parser.add_argument(
- "--base-lr", type=float, default=0.05, help="The base learning rate."
+ "--use-validated-set",
+ type=str2bool,
+ default=False,
+ help="""Use the validated set for training.
+ This is useful when you want to use more data for training,
+ but not recommended for research purposes.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-invalidated-set",
+ type=str2bool,
+ default=False,
+ help="""Use the invalidated set for training.
+ In case you want to take the risk and utilize more data for training.
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr",
+ type=float,
+ default=0.05,
+ help="The base learning rate.",
)
parser.add_argument(
@@ -888,9 +912,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@@ -1036,7 +1058,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2**22
+ 512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -1045,7 +1067,13 @@ def run(rank, world_size, args):
commonvoice = CommonVoiceAsrDataModule(args)
- train_cuts = commonvoice.train_cuts()
+ if not args.use_validated_set:
+ train_cuts = commonvoice.train_cuts()
+ else:
+ train_cuts = commonvoice.validated_cuts()
+
+ if args.use_invalidated_set:
+ train_cuts += commonvoice.invalidated_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/commonvoice/ASR/zipformer/asr_datamodule.py b/egs/commonvoice/ASR/zipformer/asr_datamodule.py
new file mode 120000
index 000000000..c274de28a
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/beam_search.py b/egs/commonvoice/ASR/zipformer/beam_search.py
new file mode 120000
index 000000000..8e2c0a65c
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/beam_search.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/decode.py b/egs/commonvoice/ASR/zipformer/decode.py
new file mode 100755
index 000000000..7fd6d0ccd
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/decode.py
@@ -0,0 +1,1052 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import CommonVoiceAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from train import add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ test_cuts = commonvoice.test_cuts()
+ dev_cuts = commonvoice.dev_cuts()
+
+ test_dl = commonvoice.test_dataloaders(test_cuts)
+ dev_dl = commonvoice.valid_dataloaders(dev_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dl = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
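
A note on the right-padding applied in decode_one_batch when --causal is given: torch.nn.functional.pad reads the pad tuple from the last dimension backwards, so pad=(0, 0, 0, pad_len) appends pad_len frames along the time axis of an (N, T, C) batch. A self-contained check (illustration only):

# Illustration of the pad=(0, 0, 0, pad_len) call in decode_one_batch above:
# for an (N, T, C) tensor the tuple pads (C_left, C_right, T_left, T_right),
# i.e. it appends pad_len frames at the end of the time axis.
import math
import torch

LOG_EPS = math.log(1e-10)
feature = torch.randn(2, 100, 80)  # (N, T, C) fbank batch
pad_len = 30
padded = torch.nn.functional.pad(feature, pad=(0, 0, 0, pad_len), value=LOG_EPS)
assert padded.shape == (2, 130, 80)
assert torch.all(padded[:, 100:, :] == LOG_EPS)
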
diff --git a/egs/commonvoice/ASR/zipformer/decode_char.py b/egs/commonvoice/ASR/zipformer/decode_char.py
new file mode 100755
index 000000000..1f8c9c7c6
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/decode_char.py
@@ -0,0 +1,813 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao
+# Mingshuang Luo,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode_char.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/zh-HK/lang_char \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) modified beam search
+./zipformer/decode_char.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/zh-HK/lang_char \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(3) fast beam search (trivial_graph)
+./zipformer/decode_char.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/zh-HK/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(4) fast beam search (LG)
+./zipformer/decode_char.py \
+ --epoch 30 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/zh-HK/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest oracle WER)
+./zipformer/decode_char.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/zh-HK/lang_char \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import CommonVoiceAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from train import add_model_arguments, get_model, get_params
+
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/zh-HK/lang_char",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_LG
+ - fast_beam_search_nbest_oracle
+ If you use fast_beam_search_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--ilme-scale",
+ type=float,
+ default=0.2,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for the internal language model estimation.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--blank-penalty",
+ type=float,
+ default=0.0,
+ help="""
+ The penalty applied on blank symbol during decoding.
+ Note: It is a positive value that would be applied to logits like
+ this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+ [batch_size, vocab] and blank id is 0).
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ x, x_lens = model.encoder_embed(feature, feature_lens)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ blank_penalty=params.blank_penalty,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "fast_beam_search_LG":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ blank_penalty=params.blank_penalty,
+ ilme_scale=params.ilme_scale,
+ )
+ for hyp in hyp_tokens:
+ sentence = "".join([lexicon.word_table[i] for i in hyp])
+ hyps.append(list(sentence))
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ blank_penalty=params.blank_penalty,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ blank_penalty=params.blank_penalty,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ blank_penalty=params.blank_penalty,
+ beam=params.beam_size,
+ )
+ for i in range(encoder_out.size(0)):
+ hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ blank_penalty=params.blank_penalty,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ blank_penalty=params.blank_penalty,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+ key = f"blank_penalty_{params.blank_penalty}"
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search_" + key: hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key += f"_beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ilme_scale_{params.ilme_scale}"
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}_" + key: hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or LG.
+        Used only when --decoding-method is fast_beam_search,
+        fast_beam_search_LG, or fast_beam_search_nbest_oracle.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+        Its value is a list of tuples. Each tuple contains three elements:
+        the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ texts = [list("".join(text.split())) for text in texts]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ graph_compiler=graph_compiler,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ this_batch.append((cut_id, ref_text, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
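+    # The summary file has a "settings\tWER" header followed by one
+    # "<settings>\t<WER>" line per decoding setting, sorted by WER, e.g.
+    # "greedy_search_blank_penalty_0.0\t10.5" (illustrative value).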
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "modified_beam_search",
+ "fast_beam_search",
+ "fast_beam_search_LG",
+ "fast_beam_search_nbest_oracle",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"_ilme_scale_{params.ilme_scale}"
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+ params.suffix += f"-blank-penalty-{params.blank_penalty}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ if "LG" in params.decoding_method:
+ lexicon = Lexicon(params.lang_dir)
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ def remove_short_utt(c: Cut):
+ T = ((c.num_frames - 7) // 2 + 1) // 2
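+        # e.g. num_frames = 8 gives T = 0 (excluded), num_frames = 9 gives T = 1.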
+ if T <= 0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
+ )
+ return T > 0
+
+ dev_cuts = commonvoice.dev_cuts()
+ dev_cuts = dev_cuts.filter(remove_short_utt)
+ dev_dl = commonvoice.valid_dataloaders(dev_cuts)
+
+ test_cuts = commonvoice.test_cuts()
+ test_cuts = test_cuts.filter(remove_short_utt)
+ test_dl = commonvoice.test_dataloaders(test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dls = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ graph_compiler=graph_compiler,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/zipformer/decode_stream.py b/egs/commonvoice/ASR/zipformer/decode_stream.py
new file mode 120000
index 000000000..b8d8ddfc4
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/decode_stream.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decode_stream.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/decoder.py b/egs/commonvoice/ASR/zipformer/decoder.py
new file mode 120000
index 000000000..5a8018680
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/encoder_interface.py b/egs/commonvoice/ASR/zipformer/encoder_interface.py
new file mode 120000
index 000000000..c2eaca671
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/encoder_interface.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/export-onnx-ctc.py b/egs/commonvoice/ASR/zipformer/export-onnx-ctc.py
new file mode 120000
index 000000000..f9d756352
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/export-onnx-ctc.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-ctc.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/commonvoice/ASR/zipformer/export-onnx-streaming-ctc.py
new file mode 120000
index 000000000..652346001
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/export-onnx-streaming-ctc.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/export-onnx-streaming.py b/egs/commonvoice/ASR/zipformer/export-onnx-streaming.py
new file mode 120000
index 000000000..2962eb784
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/export-onnx-streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-streaming.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/export-onnx.py b/egs/commonvoice/ASR/zipformer/export-onnx.py
new file mode 120000
index 000000000..70a15683c
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/export-onnx.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/export.py b/egs/commonvoice/ASR/zipformer/export.py
new file mode 120000
index 000000000..dfc1bec08
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/joiner.py b/egs/commonvoice/ASR/zipformer/joiner.py
new file mode 120000
index 000000000..5b8a36332
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/model.py b/egs/commonvoice/ASR/zipformer/model.py
new file mode 120000
index 000000000..cd7e07d72
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/model.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/onnx_check.py b/egs/commonvoice/ASR/zipformer/onnx_check.py
new file mode 120000
index 000000000..f3dd42004
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/onnx_check.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_check.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/onnx_pretrained.py b/egs/commonvoice/ASR/zipformer/onnx_pretrained.py
new file mode 120000
index 000000000..8f32f4ee7
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/onnx_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/optim.py b/egs/commonvoice/ASR/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/scaling.py b/egs/commonvoice/ASR/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/scaling_converter.py b/egs/commonvoice/ASR/zipformer/scaling_converter.py
new file mode 120000
index 000000000..b0ecee05e
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/streaming_beam_search.py b/egs/commonvoice/ASR/zipformer/streaming_beam_search.py
new file mode 120000
index 000000000..b1ed54557
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/streaming_decode.py b/egs/commonvoice/ASR/zipformer/streaming_decode.py
new file mode 100755
index 000000000..1d0230c76
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/streaming_decode.py
@@ -0,0 +1,859 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
+# Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+./zipformer/streaming_decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --causal 1 \
+ --chunk-size 32 \
+ --left-context-frames 256 \
+ --exp-dir ./zipformer/exp \
+ --decoding-method greedy_search \
+ --num-decode-streams 2000
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import numpy as np
+import sentencepiece as spm
+import torch
+from asr_datamodule import CommonVoiceAsrDataModule
+from decode_stream import DecodeStream
+from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
+from streaming_beam_search import (
+ fast_beam_search_one_best,
+ greedy_search,
+ modified_beam_search,
+)
+from torch import Tensor, nn
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Supported decoding methods are:
+ greedy_search
+ modified_beam_search
+ fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--num_active_paths",
+ type=int,
+ default=4,
+ help="""An interger indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=32,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--num-decode-streams",
+ type=int,
+ default=2000,
+ help="The number of streams that can be decoded parallel.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_init_states(
+ model: nn.Module,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+ """
+ states = model.encoder.get_init_states(batch_size, device)
+
+ embed_states = model.encoder_embed.get_init_states(batch_size, device)
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ states.append(processed_lens)
+
+ return states
+
+
+def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+ """Stack list of zipformer states that correspond to separate utterances
+    into a single zipformer state, so that it can be used as an input for
+ zipformer when those utterances are formed into a batch.
+
+ Args:
+ state_list:
+ Each element in state_list corresponding to the internal state
+ of the zipformer model for a single utterance. For element-n,
+ state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
+ state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
+ cached_val2, cached_conv1, cached_conv2).
+ state_list[n][-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ state_list[n][-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Note:
+ It is the inverse of :func:`unstack_states`.
+ """
+ batch_size = len(state_list)
+ assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
+ tot_num_layers = (len(state_list[0]) - 2) // 6
+
+ batch_states = []
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key = torch.cat(
+ [state_list[i][layer_offset] for i in range(batch_size)], dim=1
+ )
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn = torch.cat(
+ [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1 = torch.cat(
+ [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2 = torch.cat(
+ [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1 = torch.cat(
+ [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2 = torch.cat(
+ [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
+ )
+ batch_states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+
+ cached_embed_left_pad = torch.cat(
+ [state_list[i][-2] for i in range(batch_size)], dim=0
+ )
+ batch_states.append(cached_embed_left_pad)
+
+ processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
+ batch_states.append(processed_lens)
+
+ return batch_states
+
+
+def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
+ """Unstack the zipformer state corresponding to a batch of utterances
+ into a list of states, where the i-th entry is the state from the i-th
+ utterance in the batch.
+
+ Note:
+ It is the inverse of :func:`stack_states`.
+
+ Args:
+ batch_states: A list of cached tensors of all encoder layers. For layer-i,
+ states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+ cached_conv1, cached_conv2).
+ state_list[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Returns:
+ state_list: A list of list. Each element in state_list corresponding to the internal state
+ of the zipformer model for a single utterance.
+ """
+ assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
+ tot_num_layers = (len(batch_states) - 2) // 6
+
+ processed_lens = batch_states[-1]
+ batch_size = processed_lens.shape[0]
+
+ state_list = [[] for _ in range(batch_size)]
+
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1_list = batch_states[layer_offset + 2].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2_list = batch_states[layer_offset + 3].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1_list = batch_states[layer_offset + 4].chunk(
+ chunks=batch_size, dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2_list = batch_states[layer_offset + 5].chunk(
+ chunks=batch_size, dim=0
+ )
+ for i in range(batch_size):
+ state_list[i] += [
+ cached_key_list[i],
+ cached_nonlin_attn_list[i],
+ cached_val1_list[i],
+ cached_val2_list[i],
+ cached_conv1_list[i],
+ cached_conv2_list[i],
+ ]
+
+ cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(cached_embed_left_pad_list[i])
+
+ processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(processed_lens_list[i])
+
+ return state_list
+
+
+def streaming_forward(
+ features: Tensor,
+ feature_lens: Tensor,
+ model: nn.Module,
+ states: List[Tensor],
+ chunk_size: int,
+ left_context_len: int,
+) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """
+ Returns encoder outputs, output lengths, and updated states.
+ """
+ cached_embed_left_pad = states[-2]
+ (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward(
+ x=features,
+ x_lens=feature_lens,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+
+ # processed_mask is used to mask out initial states
+ processed_mask = torch.arange(left_context_len, device=x.device).expand(
+ x.size(0), left_context_len
+ )
+ processed_lens = states[-1] # (batch,)
+ # (batch, left_context_size)
+ processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ encoder_states = states[:-2]
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = model.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+ return encoder_out, encoder_out_lens, new_states
+
+
+def decode_one_chunk(
+ params: AttributeDict,
+ model: nn.Module,
+ decode_streams: List[DecodeStream],
+) -> List[int]:
+ """Decode one chunk frames of features for each decode_streams and
+ return the indexes of finished streams in a List.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ decode_streams:
+      A List of DecodeStream, each belonging to an utterance.
+ Returns:
+ Return a List containing which DecodeStreams are finished.
+ """
+ device = model.device
+ chunk_size = int(params.chunk_size)
+ left_context_len = int(params.left_context_frames)
+
+ features = []
+ feature_lens = []
+ states = []
+ processed_lens = [] # Used in fast-beam-search
+
+ for stream in decode_streams:
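+        # chunk_size counts frames at the encoder output rate (~50 Hz); fbank
+        # features are at ~100 Hz, hence chunk_size * 2 feature frames per chunk
+        # (assuming the default 10 ms frame shift).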
+ feat, feat_len = stream.get_feature_frames(chunk_size * 2)
+ features.append(feat)
+ feature_lens.append(feat_len)
+ states.append(stream.states)
+ processed_lens.append(stream.done_frames)
+
+ feature_lens = torch.tensor(feature_lens, device=device)
+ features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
+
+ # Make sure the length after encoder_embed is at least 1.
+    # The encoder_embed subsamples features: T' = (T - 7) // 2
+ # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+ tail_length = chunk_size * 2 + 7 + 2 * 3
+ if features.size(1) < tail_length:
+ pad_length = tail_length - features.size(1)
+ feature_lens += pad_length
+ features = torch.nn.functional.pad(
+ features,
+ (0, 0, 0, pad_length),
+ mode="constant",
+ value=LOG_EPS,
+ )
+
+ states = stack_states(states)
+
+ encoder_out, encoder_out_lens, new_states = streaming_forward(
+ features=features,
+ feature_lens=feature_lens,
+ model=model,
+ states=states,
+ chunk_size=chunk_size,
+ left_context_len=left_context_len,
+ )
+
+ encoder_out = model.joiner.encoder_proj(encoder_out)
+
+ if params.decoding_method == "greedy_search":
+ greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
+ elif params.decoding_method == "fast_beam_search":
+ processed_lens = torch.tensor(processed_lens, device=device)
+ processed_lens = processed_lens + encoder_out_lens
+ fast_beam_search_one_best(
+ model=model,
+ encoder_out=encoder_out,
+ processed_lens=processed_lens,
+ streams=decode_streams,
+ beam=params.beam,
+ max_states=params.max_states,
+ max_contexts=params.max_contexts,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ modified_beam_search(
+ model=model,
+ streams=decode_streams,
+ encoder_out=encoder_out,
+ num_active_paths=params.num_active_paths,
+ )
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+
+ states = unstack_states(new_states)
+
+ finished_streams = []
+ for i in range(len(decode_streams)):
+ decode_streams[i].states = states[i]
+ decode_streams[i].done_frames += encoder_out_lens[i]
+ if decode_streams[i].done:
+ finished_streams.append(i)
+
+ return finished_streams
+
+
+def decode_dataset(
+ cuts: CutSet,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ cuts:
+      Lhotse CutSet containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG.
+        Used only when --decoding-method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+ """
+ device = model.device
+
+ opts = FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+
+ log_interval = 100
+
+ decode_results = []
+    # Contains decode streams currently running.
+ decode_streams = []
+ for num, cut in enumerate(cuts):
+ # each utterance has a DecodeStream.
+ initial_states = get_init_states(model=model, batch_size=1, device=device)
+ decode_stream = DecodeStream(
+ params=params,
+ cut_id=cut.id,
+ initial_states=initial_states,
+ decoding_graph=decoding_graph,
+ device=device,
+ )
+
+ audio: np.ndarray = cut.load_audio()
+ # audio.shape: (1, num_samples)
+ assert len(audio.shape) == 2
+ assert audio.shape[0] == 1, "Should be single channel"
+ assert audio.dtype == np.float32, audio.dtype
+
+ # The trained model is using normalized samples
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
+
+ samples = torch.from_numpy(audio).squeeze(0)
+
+ fbank = Fbank(opts)
+ feature = fbank(samples.to(device))
+ decode_stream.set_features(feature, tail_pad_len=30)
+ decode_stream.ground_truth = cut.supervisions[0].text
+
+ decode_streams.append(decode_stream)
+
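+        # Once num_decode_streams streams are active, decode chunk by chunk
+        # and retire finished streams before admitting new cuts.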
+ while len(decode_streams) >= params.num_decode_streams:
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ sp.decode(decode_streams[i].decoding_result()).split(),
+ )
+ )
+ del decode_streams[i]
+
+ if num % log_interval == 0:
+ logging.info(f"Cuts processed until now is {num}.")
+
+ # decode final chunks of last sequences
+ while len(decode_streams):
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ sp.decode(decode_streams[i].decoding_result()).split(),
+ )
+ )
+ del decode_streams[i]
+
+ if params.decoding_method == "greedy_search":
+ key = "greedy_search"
+ elif params.decoding_method == "fast_beam_search":
+ key = (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ )
+ elif params.decoding_method == "modified_beam_search":
+ key = f"num_active_paths_{params.num_active_paths}"
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+ return {key: decode_results}
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ assert params.causal, params.causal
+ assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ # for fast_beam_search
+ if params.decoding_method == "fast_beam_search":
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+                if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ decoding_graph = None
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ test_cuts = commonvoice.test_cuts()
+ dev_cuts = commonvoice.dev_cuts()
+
+ test_sets = ["test", "dev"]
+ test_cuts = [test_cuts, dev_cuts]
+
+ for test_set, test_cut in zip(test_sets, test_cuts):
+ results_dict = decode_dataset(
+ cuts=test_cut,
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/zipformer/streaming_decode_char.py b/egs/commonvoice/ASR/zipformer/streaming_decode_char.py
new file mode 100755
index 000000000..249cba9f5
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/streaming_decode_char.py
@@ -0,0 +1,861 @@
+#!/usr/bin/env python3
+# Copyright 2022-2024 Xiaomi Corporation (Authors: Wei Kang,
+# Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+./zipformer/streaming_decode_char.py \
+ --epoch 28 \
+ --avg 15 \
+ --causal 1 \
+ --chunk-size 32 \
+ --left-context-frames 256 \
+ --exp-dir ./zipformer/exp \
+ --decoding-method greedy_search \
+ --num-decode-streams 2000
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import numpy as np
+import torch
+from asr_datamodule import CommonVoiceAsrDataModule
+from decode_stream import DecodeStream
+from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
+from streaming_beam_search import (
+ fast_beam_search_one_best,
+ greedy_search,
+ modified_beam_search,
+)
+from torch import Tensor, nn
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/zh-HK/lang_char",
+ help="Path to the lang dir(containing lexicon, tokens, etc.)",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Supported decoding methods are:
+ greedy_search
+ modified_beam_search
+ fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--num_active_paths",
+ type=int,
+ default=4,
+ help="""An interger indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=32,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--num-decode-streams",
+ type=int,
+ default=2000,
+ help="The number of streams that can be decoded parallel.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_init_states(
+ model: nn.Module,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+ """
+ states = model.encoder.get_init_states(batch_size, device)
+
+ embed_states = model.encoder_embed.get_init_states(batch_size, device)
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ states.append(processed_lens)
+
+ return states
+
+
+def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+ """Stack list of zipformer states that correspond to separate utterances
+    into a single zipformer state, so that it can be used as an input for
+ zipformer when those utterances are formed into a batch.
+
+ Args:
+ state_list:
+ Each element in state_list corresponding to the internal state
+ of the zipformer model for a single utterance. For element-n,
+ state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
+ state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
+ cached_val2, cached_conv1, cached_conv2).
+ state_list[n][-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ state_list[n][-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Note:
+ It is the inverse of :func:`unstack_states`.
+ """
+ batch_size = len(state_list)
+ assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
+ tot_num_layers = (len(state_list[0]) - 2) // 6
+
+ batch_states = []
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key = torch.cat(
+ [state_list[i][layer_offset] for i in range(batch_size)], dim=1
+ )
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn = torch.cat(
+ [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1 = torch.cat(
+ [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2 = torch.cat(
+ [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1 = torch.cat(
+ [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2 = torch.cat(
+ [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
+ )
+ batch_states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+
+ cached_embed_left_pad = torch.cat(
+ [state_list[i][-2] for i in range(batch_size)], dim=0
+ )
+ batch_states.append(cached_embed_left_pad)
+
+ processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
+ batch_states.append(processed_lens)
+
+ return batch_states
+
+
+def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
+ """Unstack the zipformer state corresponding to a batch of utterances
+ into a list of states, where the i-th entry is the state from the i-th
+ utterance in the batch.
+
+ Note:
+ It is the inverse of :func:`stack_states`.
+
+ Args:
+ batch_states: A list of cached tensors of all encoder layers. For layer-i,
+ states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+ cached_conv1, cached_conv2).
+ state_list[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Returns:
+ state_list: A list of list. Each element in state_list corresponding to the internal state
+ of the zipformer model for a single utterance.
+ """
+ assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
+ tot_num_layers = (len(batch_states) - 2) // 6
+
+ processed_lens = batch_states[-1]
+ batch_size = processed_lens.shape[0]
+
+ state_list = [[] for _ in range(batch_size)]
+
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1_list = batch_states[layer_offset + 2].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2_list = batch_states[layer_offset + 3].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1_list = batch_states[layer_offset + 4].chunk(
+ chunks=batch_size, dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2_list = batch_states[layer_offset + 5].chunk(
+ chunks=batch_size, dim=0
+ )
+ for i in range(batch_size):
+ state_list[i] += [
+ cached_key_list[i],
+ cached_nonlin_attn_list[i],
+ cached_val1_list[i],
+ cached_val2_list[i],
+ cached_conv1_list[i],
+ cached_conv2_list[i],
+ ]
+
+ cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(cached_embed_left_pad_list[i])
+
+ processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(processed_lens_list[i])
+
+ return state_list
+
+
+def streaming_forward(
+ features: Tensor,
+ feature_lens: Tensor,
+ model: nn.Module,
+ states: List[Tensor],
+ chunk_size: int,
+ left_context_len: int,
+) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """
+ Returns encoder outputs, output lengths, and updated states.
+ """
+ cached_embed_left_pad = states[-2]
+ (x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward(
+ x=features,
+ x_lens=feature_lens,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+
+ # processed_mask is used to mask out initial states
+ processed_mask = torch.arange(left_context_len, device=x.device).expand(
+ x.size(0), left_context_len
+ )
+ processed_lens = states[-1] # (batch,)
+ # (batch, left_context_size)
+ processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ encoder_states = states[:-2]
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = model.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+ return encoder_out, encoder_out_lens, new_states
+
+
+def decode_one_chunk(
+ params: AttributeDict,
+ model: nn.Module,
+ decode_streams: List[DecodeStream],
+) -> List[int]:
+ """Decode one chunk frames of features for each decode_streams and
+ return the indexes of finished streams in a List.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ decode_streams:
+ A list of DecodeStream objects, each corresponding to one utterance.
+ Returns:
+ Return a list containing the indexes of the finished DecodeStreams.
+ """
+ device = model.device
+ chunk_size = int(params.chunk_size)
+ left_context_len = int(params.left_context_frames)
+
+ features = []
+ feature_lens = []
+ states = []
+ processed_lens = [] # Used in fast-beam-search
+
+ for stream in decode_streams:
+ feat, feat_len = stream.get_feature_frames(chunk_size * 2)
+ features.append(feat)
+ feature_lens.append(feat_len)
+ states.append(stream.states)
+ processed_lens.append(stream.done_frames)
+
+ feature_lens = torch.tensor(feature_lens, device=device)
+ features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
+
+ # Make sure the length after encoder_embed is at least 1.
+ # The encoder_embed subsamples features: T -> (T - 7) // 2.
+ # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling.
+ tail_length = chunk_size * 2 + 7 + 2 * 3
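+ # For example, with chunk_size = 16 (i.e. 32 feature frames per chunk),
+ # tail_length = 32 + 7 + 6 = 45.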
+ if features.size(1) < tail_length:
+ pad_length = tail_length - features.size(1)
+ feature_lens += pad_length
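+ # torch.nn.functional.pad pads the last dimensions first: (0, 0) leaves the
+ # feature dimension untouched, and (0, pad_length) appends pad_length frames
+ # at the end of the time axis.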
+ features = torch.nn.functional.pad(
+ features,
+ (0, 0, 0, pad_length),
+ mode="constant",
+ value=LOG_EPS,
+ )
+
+ states = stack_states(states)
+
+ encoder_out, encoder_out_lens, new_states = streaming_forward(
+ features=features,
+ feature_lens=feature_lens,
+ model=model,
+ states=states,
+ chunk_size=chunk_size,
+ left_context_len=left_context_len,
+ )
+
+ encoder_out = model.joiner.encoder_proj(encoder_out)
+
+ if params.decoding_method == "greedy_search":
+ greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
+ elif params.decoding_method == "fast_beam_search":
+ processed_lens = torch.tensor(processed_lens, device=device)
+ processed_lens = processed_lens + encoder_out_lens
+ fast_beam_search_one_best(
+ model=model,
+ encoder_out=encoder_out,
+ processed_lens=processed_lens,
+ streams=decode_streams,
+ beam=params.beam,
+ max_states=params.max_states,
+ max_contexts=params.max_contexts,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ modified_beam_search(
+ model=model,
+ streams=decode_streams,
+ encoder_out=encoder_out,
+ num_active_paths=params.num_active_paths,
+ )
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+
+ states = unstack_states(new_states)
+
+ finished_streams = []
+ for i in range(len(decode_streams)):
+ decode_streams[i].states = states[i]
+ decode_streams[i].done_frames += encoder_out_lens[i]
+ if decode_streams[i].done:
+ finished_streams.append(i)
+
+ return finished_streams
+
+
+def decode_dataset(
+ cuts: CutSet,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ cuts:
+ Lhotse CutSet containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ lexicon:
+ The lexicon; its token table is used to map decoded token IDs to symbols.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+ only when --decoding_method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut ID, the reference transcript, and the predicted result.
+ """
+ device = model.device
+
+ opts = FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+
+ log_interval = 100
+
+ decode_results = []
+ # Contains the decode streams currently running.
+ decode_streams = []
+ for num, cut in enumerate(cuts):
+ # each utterance has a DecodeStream.
+ initial_states = get_init_states(model=model, batch_size=1, device=device)
+ decode_stream = DecodeStream(
+ params=params,
+ cut_id=cut.id,
+ initial_states=initial_states,
+ decoding_graph=decoding_graph,
+ device=device,
+ )
+
+ audio: np.ndarray = cut.load_audio()
+ # audio.shape: (1, num_samples)
+ assert len(audio.shape) == 2
+ assert audio.shape[0] == 1, "Should be single channel"
+ assert audio.dtype == np.float32, audio.dtype
+
+ # The trained model uses normalized samples:
+ # - this avoids feeding the raw 16-bit range [-32k, +32k] into the model;
+ # - some lhotse AudioTransform classes can push the signal slightly
+ # out of the range [-1, 1], hence the tolerance of 10.
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
+
+ samples = torch.from_numpy(audio).squeeze(0)
+
+ fbank = Fbank(opts)
+ feature = fbank(samples.to(device))
+ decode_stream.set_features(feature, tail_pad_len=30)
+ decode_stream.ground_truth = cut.supervisions[0].text
+
+ decode_streams.append(decode_stream)
+
+ while len(decode_streams) >= params.num_decode_streams:
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ [
+ lexicon.token_table[idx]
+ for idx in decode_streams[i].decoding_result()
+ ],
+ )
+ )
+ del decode_streams[i]
+
+ if num % log_interval == 0:
+ logging.info(f"Cuts processed until now is {num}.")
+
+ # Decode the final chunks of the remaining streams.
+ while len(decode_streams):
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ [
+ lexicon.token_table[idx]
+ for idx in decode_streams[i].decoding_result()
+ ],
+ )
+ )
+ del decode_streams[i]
+
+ if params.decoding_method == "greedy_search":
+ key = "greedy_search"
+ elif params.decoding_method == "fast_beam_search":
+ key = (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ )
+ elif params.decoding_method == "modified_beam_search":
+ key = f"num_active_paths_{params.num_active_paths}"
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+ return {key: decode_results}
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ assert params.causal, params.causal
+ assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ # for fast_beam_search
+ if params.decoding_method == "fast_beam_search":
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ decoding_graph = None
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ test_cuts = commonvoice.test_cuts()
+ dev_cuts = commonvoice.dev_cuts()
+
+ test_sets = ["test", "dev"]
+ test_cuts = [test_cuts, dev_cuts]
+
+ for test_set, test_cut in zip(test_sets, test_cuts):
+ results_dict = decode_dataset(
+ cuts=test_cut,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/zipformer/subsampling.py b/egs/commonvoice/ASR/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py
new file mode 100755
index 000000000..5cda9bfd4
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/train.py
@@ -0,0 +1,1411 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# For non-streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import CommonVoiceAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
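+ # For example (hypothetical values): with max_duration=1000, world_size=4
+ # and ref_duration=600, one real batch counts as 1000 * 4 / 600 ~= 6.7
+ # reference batches, so after 300 real batches the adjusted count is ~2000.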
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/en/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--use-validated-set",
+ type=str2bool,
+ default=False,
+ help="""Use the validated set for training.
+ This is useful when you want to use more data for training,
+ but not recommended for research purposes.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-invalidated-set",
+ type=str2bool,
+ default=False,
+ help="""Use the invalidated set for training.
+ In case you want to take the risk and utilize more data for training.
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr",
+ type=float,
+ default=0.045,
+ help="The base learning rate.",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
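+ # Worked example for the formula above (hypothetical values): with
+ # average_period = 200 and batch_idx_train = 1000, the update is
+ # model_avg = 0.2 * model + 0.8 * model_avg.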
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
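+ # As a rough example: a 10-second utterance has about 1000 feature frames
+ # (10 ms frame shift), giving (1000 - 7) // 2 = 496 frames after
+ # encoder_embed and roughly 248 frames after the final downsampling by 2.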
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # Ramp the scale on the simple loss down from 1.0 at the start of
+ # training to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
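+ # Worked example (hypothetical values): with s = 0.5 and
+ # batch_idx_train = warm_step / 2, simple_loss_scale = 1.0 - 0.5 * 0.5 = 0.75
+ # and pruned_loss_scale = 0.1 + 0.9 * 0.5 = 0.55.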
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
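+ # For example, a grad scale of 4.0 is doubled at every 100-batch check,
+ # while a scale of 16.0 is doubled only on the checks where
+ # batch_idx % 400 == 0.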
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ if not args.use_validated_set:
+ train_cuts = commonvoice.train_cuts()
+ else:
+ train_cuts = commonvoice.validated_cuts()
+
+ if args.use_invalidated_set:
+ train_cuts += commonvoice.invalidated_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
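+ # For example, a 10-second cut with a 10 ms frame shift has about
+ # 1000 frames, so T = ((1000 - 7) // 2 + 1) // 2 = 248.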
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = commonvoice.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ dev_cuts = commonvoice.dev_cuts()
+ dev_dl = commonvoice.valid_dataloaders(dev_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=dev_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py
new file mode 100755
index 000000000..a780bbbbc
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/train_char.py
@@ -0,0 +1,1051 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# For non-streaming model training:
+./zipformer/train_char.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/train_char.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import CommonVoiceAsrDataModule
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from train import (
+ add_model_arguments,
+ get_adjusted_batch_count,
+ get_model,
+ load_checkpoint_if_available,
+ save_checkpoint,
+ set_batch_count,
+)
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/zh-HK/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--use-validated-set",
+ type=str2bool,
+ default=False,
+ help="""Use the validated set for training.
+ This is useful when you want to use more data for training,
+ but not recommended for research purposes.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-invalidated-set",
+ type=str2bool,
+ default=False,
+ help="""Use the invalidated set for training.
+ In case you want to take the risk and utilize more data for training.
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr",
+ type=float,
+ default=0.045,
+ help="The base learning rate.",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # ramp the scale on the simple loss down from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
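+ # e.g. with s = 0.5 and warm_step = 2000 (the defaults in this script),
+ # the scale is 0.875 at batch 500, 0.75 at batch 1000, and stays at 0.5
+ # once batch_idx_train >= 2000.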
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ texts = supervisions["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ if not args.use_validated_set:
+ train_cuts = commonvoice.train_cuts()
+ else:
+ train_cuts = commonvoice.validated_cuts()
+
+ if args.use_invalidated_set:
+ train_cuts += commonvoice.invalidated_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
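+ # e.g. assuming the usual 10 ms frame shift, a 5 second cut has
+ # num_frames = 500, so T = ((500 - 7) // 2 + 1) // 2 = 123.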
+ tokens = graph_compiler.texts_to_ids([c.supervisions[0].text])[0]
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = commonvoice.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ dev_cuts = commonvoice.dev_cuts()
+ dev_dl = commonvoice.valid_dataloaders(dev_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=dev_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+# torch.set_num_threads(1)
+# torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/zipformer/zipformer.py b/egs/commonvoice/ASR/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/commonvoice/ASR/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py
index 042b6ecbf..7bf7bdef0 100644
--- a/egs/csj/ASR/local/utils/asr_datamodule.py
+++ b/egs/csj/ASR/local/utils/asr_datamodule.py
@@ -336,6 +336,8 @@ class CSJAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 685f6ece6..6d256308c 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -909,9 +910,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
index b210430c6..06a0fa96b 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
@@ -70,9 +70,9 @@ import logging
from pathlib import Path
import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
from tokenizer import Tokenizer
-from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
index 73fcd67aa..ef7ea9013 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -908,9 +909,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/fluent_speech_commands/SLU/README.md b/egs/fluent_speech_commands/SLU/README.md
new file mode 100755
index 000000000..a203a9bfb
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/README.md
@@ -0,0 +1,9 @@
+## Fluent Speech Commands recipe
+
+This is a recipe for the Fluent Speech Commands dataset, a spoken language understanding (SLU) dataset that maps short utterances (such as "turn the lights on in the kitchen") to action frames (such as {"action": "activate", "object": "lights", "location": "kitchen"}). The training set contains 23,132 utterances and the test set contains 3,793 utterances.
+
+Dataset Paper link:
+
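+Data preparation (manifests, fbank features, lexicon, LM and HLG) is handled by `prepare.sh`. A minimal sketch, assuming the dataset path is passed as `--data-dir` (the script sources `shared/parse_options.sh`, which maps such options onto its variables):
+
+```bash
+cd icefall/egs/fluent_speech_commands/SLU
+./prepare.sh --data-dir /path/to/fluent_speech_commands
+```
+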
+cd icefall/egs/fluent_speech_commands/SLU
+Training: python transducer/train.py
+Decoding: python transducer/decode.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/local/compile_hlg.py b/egs/fluent_speech_commands/SLU/local/compile_hlg.py
new file mode 100755
index 000000000..a7df8f966
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/local/compile_hlg.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python3
+
+"""
+This script takes as input lang_dir and generates HLG from
+
+ - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+ - L, the lexicon, built from lang_dir/L_disambig.pt
+
+ Caution: We use a lexicon that contains disambiguation symbols
+
+ - G, the LM, built from lang_dir/G.fst.txt
+
+The generated HLG is saved in $lang_dir/HLG.pt
+"""
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def compile_HLG(lang_dir: str) -> k2.Fsa:
+ """
+ Args:
+ lang_dir:
+ The language directory, e.g., data/lm/frames or data/lm/transcript.
+
+ Return:
+ An FSA representing HLG.
+ """
+ lexicon = Lexicon(lang_dir)
+ max_token_id = max(lexicon.tokens)
+ logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
+ H = k2.ctc_topo(max_token_id)
+ L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+
+ logging.info("Loading G.fst.txt")
+ with open(lang_dir / "G.fst.txt") as f:
+ G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+
+ first_token_disambig_id = lexicon.token_table["#0"]
+ first_word_disambig_id = lexicon.word_table["#0"]
+
+ L = k2.arc_sort(L)
+ G = k2.arc_sort(G)
+
+ logging.info("Intersecting L and G")
+ LG = k2.compose(L, G)
+ logging.info(f"LG shape: {LG.shape}")
+
+ logging.info("Connecting LG")
+ LG = k2.connect(LG)
+ logging.info(f"LG shape after k2.connect: {LG.shape}")
+
+ logging.info(type(LG.aux_labels))
+ logging.info("Determinizing LG")
+
+ LG = k2.determinize(LG)
+ logging.info(type(LG.aux_labels))
+
+ logging.info("Connecting LG after k2.determinize")
+ LG = k2.connect(LG)
+
+ logging.info("Removing disambiguation symbols on LG")
+
+ # LG.labels[LG.labels >= first_token_disambig_id] = 0
+ # see https://github.com/k2-fsa/k2/pull/1140
+ labels = LG.labels
+ labels[labels >= first_token_disambig_id] = 0
+ LG.labels = labels
+
+ assert isinstance(LG.aux_labels, k2.RaggedTensor)
+ LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
+
+ LG = k2.remove_epsilon(LG)
+ logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
+
+ LG = k2.connect(LG)
+ LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+
+ logging.info("Arc sorting LG")
+ LG = k2.arc_sort(LG)
+
+ logging.info("Composing H and LG")
+ # CAUTION: The name of the inner_labels is fixed
+ # to `tokens`. If you want to change it, please
+ # also change other places in icefall that are using
+ # it.
+ HLG = k2.compose(H, LG, inner_labels="tokens")
+
+ logging.info("Connecting LG")
+ HLG = k2.connect(HLG)
+
+ logging.info("Arc sorting LG")
+ HLG = k2.arc_sort(HLG)
+ logging.info(f"HLG.shape: {HLG.shape}")
+
+ return HLG
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+
+ if (lang_dir / "HLG.pt").is_file():
+ logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+ return
+
+ logging.info(f"Processing {lang_dir}")
+
+ HLG = compile_HLG(lang_dir)
+ logging.info(f"Saving HLG.pt to {lang_dir}")
+ torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py
new file mode 100755
index 000000000..a51b7b47b
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+
+"""
+This file computes fbank features of the Fluent Speech Commands dataset.
+It reads the manifests from the manifest directory given on the command line.
+
+The generated fbank features are saved in the fbank directory given on the
+command line.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or it wastes a
+# lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_slu(manifest_dir, fbanks_dir):
+ src_dir = Path(manifest_dir)
+ output_dir = Path(fbanks_dir)
+
+ # This dataset is rather small, so we use only one job
+ num_jobs = min(1, os.cpu_count())
+ num_mel_bins = 23
+
+ dataset_parts = (
+ "train",
+ "valid",
+ "test",
+ )
+ prefix = "slu"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}"
+ if cuts_file.is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ if "train" in partition:
+ cut_set = (
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ )
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 1, # use one job
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(cuts_file)
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("manifest_dir")
+parser.add_argument("fbanks_dir")
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ args = parser.parse_args()
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ compute_fbank_slu(args.manifest_dir, args.fbanks_dir)
diff --git a/egs/fluent_speech_commands/SLU/local/generate_lexicon.py b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py
new file mode 100755
index 000000000..6263e062f
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py
@@ -0,0 +1,59 @@
+import argparse
+
+import pandas
+from tqdm import tqdm
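+
+# Builds word lists for the transcripts and for the action/object/location
+# frames from the dataset's train_data.csv. Example (as invoked from
+# prepare.sh):
+#   python ./local/generate_lexicon.py /path/to/fluent/speech/commands data/lm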
+
+
+def generate_lexicon(corpus_dir, lm_dir):
+ data = pandas.read_csv(
+ str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0
+ )
+ vocab_transcript = set()
+ vocab_frames = set()
+ transcripts = data["transcription"].tolist()
+ frames = list(
+ zip(
+ data["action"].tolist(), data["object"].tolist(), data["location"].tolist()
+ )
+ )
+
+ for transcript in tqdm(transcripts):
+ for word in transcript.split():
+ vocab_transcript.add(word)
+
+ for frame in tqdm(frames):
+ for word in frame:
+ vocab_frames.add("_".join(word.split()))
+
+ with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file:
+ lexicon_transcript_file.write(" 1" + "\n")
+ lexicon_transcript_file.write(" 2" + "\n")
+ lexicon_transcript_file.write(" 0" + "\n")
+ id = 3
+ for vocab in vocab_transcript:
+ lexicon_transcript_file.write(vocab + " " + str(id) + "\n")
+ id += 1
+
+ with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file:
+ lexicon_frames_file.write(" 1" + "\n")
+ lexicon_frames_file.write(" 2" + "\n")
+ lexicon_frames_file.write(" 0" + "\n")
+ id = 3
+ for vocab in vocab_frames:
+ lexicon_frames_file.write(vocab + " " + str(id) + "\n")
+ id += 1
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("corpus_dir")
+parser.add_argument("lm_dir")
+
+
+def main():
+ args = parser.parse_args()
+
+ generate_lexicon(args.corpus_dir, args.lm_dir)
+
+
+main()
diff --git a/egs/fluent_speech_commands/SLU/local/prepare_lang.py b/egs/fluent_speech_commands/SLU/local/prepare_lang.py
new file mode 100755
index 000000000..2a71dcf81
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/local/prepare_lang.py
@@ -0,0 +1,371 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+
+"""
+This script takes as input the lexicon files "words_frames.txt" and
+"words_transcript.txt" in the given lm_dir (in this recipe the words
+themselves serve as tokens) and, for each of them, does the following:
+
+1. Add disambiguation symbols to the lexicon and generate lexicon_disambig_{frames,transcript}.txt
+
+2. Generate tokens_{frames,transcript}.txt, the token table mapping a token to a unique integer.
+
+3. Generate words_{frames,transcript}.txt, the word table mapping a word to a unique integer.
+
+4. Generate L_{frames,transcript}.pt, in k2 format. It can be loaded by
+
+ d = torch.load("L_frames.pt")
+ lexicon = k2.Fsa.from_dict(d)
+
+5. Generate L_disambig_{frames,transcript}.pt, in k2 format.
+"""
+import argparse
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import k2
+import torch
+
+from icefall.lexicon import read_lexicon, write_lexicon
+
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_tokens(lexicon: Lexicon) -> List[str]:
+ """Get tokens from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique tokens.
+ """
+ ans = set()
+ for _, tokens in lexicon:
+ ans.update(tokens)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def get_words(lexicon: Lexicon) -> List[str]:
+ """Get words from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique words.
+ """
+ ans = set()
+ for word, _ in lexicon:
+ ans.add(word)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
+ """It adds pseudo-token disambiguation symbols #1, #2 and so on
+ at the ends of tokens to ensure that all pronunciations are different,
+ and that none is a prefix of another.
+
+ See also add_lex_disambig.pl from kaldi.
+
+ Args:
+ lexicon:
+ It is returned by :func:`read_lexicon`.
+ Returns:
+ Return a tuple with two elements:
+
+ - The output lexicon with disambiguation symbols
+ - The ID of the max disambiguation symbol that appears
+ in the lexicon
+ """
+
+ # (1) Work out the count of each token-sequence in the
+ # lexicon.
+ count = defaultdict(int)
+ for _, tokens in lexicon:
+ count[" ".join(tokens)] += 1
+
+ # (2) For each left sub-sequence of each token-sequence, note down
+ # that it exists (for identifying prefixes of longer strings).
+ issubseq = defaultdict(int)
+ for _, tokens in lexicon:
+ tokens = tokens.copy()
+ tokens.pop()
+ while tokens:
+ issubseq[" ".join(tokens)] = 1
+ tokens.pop()
+
+ # (3) For each entry in the lexicon:
+ # if the token sequence is unique and is not a
+ # prefix of another word, no disambig symbol.
+ # Else output #1, or #2, #3, ... if the same token-seq
+ # has already been assigned a disambig symbol.
+ ans = []
+
+ # We start with #1 since #0 has its own purpose
+ first_allowed_disambig = 1
+ max_disambig = first_allowed_disambig - 1
+ last_used_disambig_symbol_of = defaultdict(int)
+
+ for word, tokens in lexicon:
+ tokenseq = " ".join(tokens)
+ assert tokenseq != ""
+ if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
+ ans.append((word, tokens))
+ continue
+
+ cur_disambig = last_used_disambig_symbol_of[tokenseq]
+ if cur_disambig == 0:
+ cur_disambig = first_allowed_disambig
+ else:
+ cur_disambig += 1
+
+ if cur_disambig > max_disambig:
+ max_disambig = cur_disambig
+ last_used_disambig_symbol_of[tokenseq] = cur_disambig
+ tokenseq += f" #{cur_disambig}"
+ ans.append((word, tokenseq.split()))
+ return ans, max_disambig
+
+
+def generate_id_map(symbols: List[str]) -> Dict[str, int]:
+ """Generate ID maps, i.e., map a symbol to a unique ID.
+
+ Args:
+ symbols:
+ A list of unique symbols.
+ Returns:
+ A dict containing the mapping between symbols and IDs.
+ """
+ return {sym: i for i, sym in enumerate(symbols)}
+
+
+def add_self_loops(
+ arcs: List[List[Any]], disambig_token: int, disambig_word: int
+) -> List[List[Any]]:
+ """Adds self-loops to states of an FST to propagate disambiguation symbols
+ through it. They are added on each state with non-epsilon output symbols
+ on at least one arc out of the state.
+
+ See also fstaddselfloops.pl from Kaldi. One difference is that
+ Kaldi uses OpenFst style FSTs and it has multiple final states.
+ This function uses k2 style FSTs and it does not need to add self-loops
+ to the final state.
+
+ The input label of a self-loop is `disambig_token`, while the output
+ label is `disambig_word`.
+
+ Args:
+ arcs:
+ A list-of-list. The sublist contains
+ `[src_state, dest_state, label, aux_label, score]`
+ disambig_token:
+ It is the token ID of the symbol `#0`.
+ disambig_word:
+ It is the word ID of the symbol `#0`.
+
+ Return:
+ Return new `arcs` containing self-loops.
+ """
+ states_needs_self_loops = set()
+ for arc in arcs:
+ src, dst, ilabel, olabel, score = arc
+ if olabel != 0:
+ states_needs_self_loops.add(src)
+
+ ans = []
+ for s in states_needs_self_loops:
+ ans.append([s, s, disambig_token, disambig_word, 0])
+
+ return arcs + ans
+
+
+def lexicon_to_fst(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ sil_token: str = "!SIL",
+ sil_prob: float = 0.5,
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format) with optional silence at
+ the beginning and end of each word.
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ sil_token:
+ The silence token.
+ sil_prob:
+ The probability for adding a silence at the beginning and end
+ of the word.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ assert sil_prob > 0.0 and sil_prob < 1.0
+ # CAUTION: we use score, i.e, negative cost.
+ sil_score = math.log(sil_prob)
+ no_sil_score = math.log(1.0 - sil_prob)
+
+ start_state = 0
+ loop_state = 1 # words enter and leave from here
+ sil_state = 2 # words terminate here when followed by silence; this state
+ # has a silence transition to loop_state.
+ next_state = 3 # the next un-allocated state, will be incremented as we go.
+ arcs = []
+
+ # assert token2id[""] == 0
+ # assert word2id[""] == 0
+
+ eps = 0
+ sil_token = word2id[sil_token]
+
+ arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+ arcs.append([start_state, sil_state, eps, eps, sil_score])
+ arcs.append([sil_state, loop_state, sil_token, eps, 0])
+
+ for word, tokens in lexicon:
+ assert len(tokens) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ tokens = [word2id[i] for i in tokens]
+
+ for i in range(len(tokens) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, tokens[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last token of this word
+ # It has two out-going arcs, one to the loop state,
+ # the other one to the sil_state.
+ i = len(tokens) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
+ arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+
+ if need_self_loops:
+ disambig_token = word2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("lm_dir")
+
+
+def main():
+ args = parser.parse_args()
+
+ out_dir = Path(args.lm_dir)
+ lexicon_filenames = [out_dir / "words_frames.txt", out_dir / "words_transcript.txt"]
+ names = ["frames", "transcript"]
+ sil_token = "!SIL"
+ sil_prob = 0.5
+
+ for name, lexicon_filename in zip(names, lexicon_filenames):
+ lexicon = read_lexicon(lexicon_filename)
+ tokens = get_words(lexicon)
+ words = get_words(lexicon)
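+ # In this recipe the words themselves act as tokens, so the token and
+ # word inventories are identical (note token2id=word2id further below).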
+ new_lexicon = []
+ for lexicon_item in lexicon:
+ new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
+ lexicon = new_lexicon
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in tokens
+ tokens.append(f"#{i}")
+
+ tokens = [""] + tokens
+ words = ["eps"] + words + ["#0", "!SIL"]
+
+ token2id = generate_id_map(tokens)
+ word2id = generate_id_map(words)
+
+ write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
+ write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
+ write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
+
+ L = lexicon_to_fst(
+ lexicon,
+ token2id=word2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ )
+
+ L_disambig = lexicon_to_fst(
+ lexicon_disambig,
+ token2id=word2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
+ torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
+
+ if False:
+ # Just for debugging, will remove it
+ L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
+ L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
+ L_disambig.labels_sym = L.labels_sym
+ L_disambig.aux_labels_sym = L.aux_labels_sym
+ L.draw(out_dir / "L.png", title="L")
+ L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+
+
+main()
diff --git a/egs/fluent_speech_commands/SLU/prepare.sh b/egs/fluent_speech_commands/SLU/prepare.sh
new file mode 100755
index 000000000..3ff339d91
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/prepare.sh
@@ -0,0 +1,103 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=1
+stop_stage=5
+
+data_dir=path/to/fluent/speech/commands
+target_root_dir=data/
+
+lang_dir=${target_root_dir}/lang_phone
+lm_dir=${target_root_dir}/lm
+manifest_dir=${target_root_dir}/manifests
+fbanks_dir=${target_root_dir}/fbanks
+
+. shared/parse_options.sh || exit 1
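+
+# Usage sketch: parse_options.sh maps command-line options such as --data-dir,
+# --stage and --stop-stage onto the variables defined above, e.g.
+#   ./prepare.sh --data-dir /path/to/fluent/speech/commands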
+
+mkdir -p $lang_dir
+mkdir -p $lm_dir
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "data_dir: $data_dir"
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare slu manifest"
+ mkdir -p $manifest_dir
+ lhotse prepare slu $data_dir $manifest_dir
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute fbank for SLU"
+ mkdir -p $fbanks_dir
+ python ./local/compute_fbank_slu.py $manifest_dir $fbanks_dir
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare lang"
+ # NOTE: " SIL" is added for implementation convenience
+ # as the graph compiler code requires that there is an OOV word
+ # in the lexicon.
+ python ./local/generate_lexicon.py $data_dir $lm_dir
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Train LM"
+ # We use a unigram G
+ ./shared/make_kn_lm.py \
+ -ngram-order 1 \
+ -text $lm_dir/words_transcript.txt \
+ -lm $lm_dir/G_transcript.arpa
+
+ ./shared/make_kn_lm.py \
+ -ngram-order 1 \
+ -text $lm_dir/words_frames.txt \
+ -lm $lm_dir/G_frames.arpa
+
+ python ./local/prepare_lang.py $lm_dir
+
+ if [ ! -f $lm_dir/G_transcript.fst.txt ]; then
+ python -m kaldilm \
+ --read-symbol-table="$lm_dir/words_transcript.txt" \
+ $lm_dir/G_transcript.arpa > $lm_dir/G_transcript.fst.txt
+ fi
+
+ if [ ! -f $lm_dir/G_frames.fst.txt ]; then
+ python -m kaldilm \
+ --read-symbol-table="$lm_dir/words_frames.txt" \
+ $lm_dir/G_frames.arpa > $lm_dir/G_frames.fst.txt
+ fi
+
+ mkdir -p $lm_dir/frames
+ mkdir -p $lm_dir/transcript
+
+ chmod -R +777 .
+
+ for i in G_frames.arpa G_frames.fst.txt L_disambig_frames.pt L_frames.pt lexicon_disambig_frames.txt tokens_frames.txt words_frames.txt;
+ do
+ j=${i//"_frames"/}
+ mv "$lm_dir/$i" $lm_dir/frames/$j
+ done
+
+ for i in G_transcript.arpa G_transcript.fst.txt L_disambig_transcript.pt L_transcript.pt lexicon_disambig_transcript.txt tokens_transcript.txt words_transcript.txt;
+ do
+ j=${i//"_transcript"/}
+ mv "$lm_dir/$i" $lm_dir/transcript/$j
+ done
+fi
+
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Compile HLG"
+ ./local/compile_hlg.py --lang-dir $lm_dir/frames
+ ./local/compile_hlg.py --lang-dir $lm_dir/transcript
+
+fi
diff --git a/egs/fluent_speech_commands/SLU/shared b/egs/fluent_speech_commands/SLU/shared
new file mode 120000
index 000000000..9115c7e17
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
diff --git a/egs/fluent_speech_commands/SLU/transducer/__init__.py b/egs/fluent_speech_commands/SLU/transducer/__init__.py
new file mode 100755
index 000000000..e69de29bb
diff --git a/egs/fluent_speech_commands/SLU/transducer/beam_search.py b/egs/fluent_speech_commands/SLU/transducer/beam_search.py
new file mode 100755
index 000000000..a16aa0123
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/beam_search.py
@@ -0,0 +1,71 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+from transducer.model import Transducer
+
+
+def greedy_search(
+ model: Transducer, encoder_out: torch.Tensor, id2word: dict
+) -> List[str]:
+ """
+ Args:
+ model:
+ An instance of `Transducer`.
+ encoder_out:
+ A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
+ id2word:
+ A dict mapping network output IDs to output words.
+ Returns:
+ Return the decoded result as a list of words.
+ """
+ assert encoder_out.ndim == 3
+
+ # support only batch_size == 1 for now
+ assert encoder_out.size(0) == 1, encoder_out.size(0)
+ blank_id = model.decoder.blank_id
+ device = model.device
+
+ sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+ decoder_out, (h, c) = model.decoder(sos)
+ T = encoder_out.size(1)
+ t = 0
+ hyp = []
+ max_u = 1000 # terminate after this number of steps
+ u = 0
+
+ while t < T and u < max_u:
+ # fmt: off
+ current_encoder_out = encoder_out[:, t:t+1, :]
+ # fmt: on
+ logits = model.joiner(current_encoder_out, decoder_out)
+
+ log_prob = logits.log_softmax(dim=-1)
+ # log_prob is (N, 1, 1)
+ # TODO: Use logits.argmax()
+ y = log_prob.argmax()
+ if y != blank_id:
+ hyp.append(y.item())
+ y = y.reshape(1, 1)
+ decoder_out, (h, c) = model.decoder(y, (h, c))
+ u += 1
+ else:
+ t += 1
+ # id2word = {1: "YES", 2: "NO"}
+
+ hyp = [id2word[i] for i in hyp]
+
+ return hyp
diff --git a/egs/fluent_speech_commands/SLU/transducer/conformer.py b/egs/fluent_speech_commands/SLU/transducer/conformer.py
new file mode 120000
index 000000000..8be0dc864
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/conformer.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/decode.py b/egs/fluent_speech_commands/SLU/transducer/decode.py
new file mode 100755
index 000000000..ba2b9aaea
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/decode.py
@@ -0,0 +1,346 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from transducer.beam_search import greedy_search
+from transducer.conformer import Conformer
+from transducer.decoder import Decoder
+from transducer.joiner import Joiner
+from transducer.model import Transducer
+from transducer.slu_datamodule import SluDataModule
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.env import get_env_info
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ write_error_stats,
+)
+
+
+def get_id2word(params):
+ id2word = {}
+
+ # 0 is blank
+ id = 1
+ try:
+ with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
+ for line in lexicon_file:
+ if len(line.strip()) > 0:
+ id2word[id] = line.split()[0]
+ id += 1
+ except Exception:
+ # The lexicon file may be missing; fall back to an empty mapping.
+ pass
+
+ return id2word
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=6,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=1,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="transducer/exp",
+ help="Directory from which to load the checkpoints",
+ )
+ parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ "feature_dim": 23,
+ "lang_dir": Path("data/lm/frames"),
+ # encoder/decoder params
+ "vocab_size": 3, # blank, yes, no
+ "blank_id": 0,
+ "embedding_dim": 32,
+ "hidden_dim": 16,
+ "num_decoder_layers": 4,
+ }
+ )
+
+ vocab_size = 1
+ with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file:
+ for line in lexicon_file:
+ if (
+ len(line.strip()) > 0
+ ): # and '' not in line and '' not in line and '' not in line:
+ vocab_size += 1
+ params.vocab_size = vocab_size
+
+ return params
+
+
+def decode_one_batch(
+ params: AttributeDict, model: nn.Module, batch: dict, id2word: dict
+) -> List[List[int]]:
+ """Decode one batch and return the result in a list-of-list.
+ Each sub list contains the word IDs for an utterance in the batch.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+
+ - params.method is "1best", it uses 1best decoding.
+ - params.method is "nbest", it uses nbest decoding.
+
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
+ Returns:
+ Return the decoding result. `len(ans)` == batch size.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+ feature_lens = batch["supervisions"]["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+ hyps = []
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ hyp = greedy_search(model=model, encoder_out=encoder_out_i, id2word=id2word)
+ hyps.append(hyp)
+
+ # hyps = [[word_table[i] for i in ids] for ids in hyps]
+ return hyps
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+) -> List[Tuple[List[int], List[int]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ Returns:
+ Return a list of tuples. Each tuple contains the cut ID, the
+ reference transcript, and the predicted result.
+ """
+ results = []
+
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ id2word = get_id2word(params)
+
+ results = []
+ for batch_idx, batch in enumerate(dl):
+ texts = [
+ " ".join(a.supervisions[0].custom["frames"])
+ for a in batch["supervisions"]["cut"]
+ ]
+ texts = [
+ " " + a.replace("change language", "change_language") + " "
+ for a in texts
+ ]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps = decode_one_batch(
+ params=params, model=model, batch=batch, id2word=id2word
+ )
+
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results.extend(this_batch)
+
+ num_cuts += len(batch["supervisions"]["text"])
+
+ if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
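+ # Each element appended to `results` above is a (cut_id, ref_words, hyp_words)
+ # triple, e.g. ("some-cut-id", ["change_language", "none"], ["change_language"])
+ # (hypothetical values); this is the format consumed by store_transcripts()
+ # and write_error_stats() in save_results() below.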
+
+def save_results(
+ exp_dir: Path,
+ test_set_name: str,
+ results: List[Tuple[List[int], List[int]]],
+) -> None:
+ """Save results to `exp_dir`.
+ Args:
+ exp_dir:
+ The output directory. This function creates the following files inside
+ this directory:
+
+ - recogs-{test_set_name}.txt
+
+ It contains the reference and hypothesis results, like below::
+
+ ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
+ hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
+ ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
+ hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
+
+ - errs-{test_set_name}.txt
+
+ It contains the detailed WER.
+ test_set_name:
+ The name of the test set, which will be part of the result filename.
+ results:
+ A list of tuples, each of which contains (cut_id, ref_words, hyp_words).
+ Returns:
+ Return None.
+ """
+ recog_path = exp_dir / f"recogs-{test_set_name}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = exp_dir / f"errs-{test_set_name}.txt"
+ with open(errs_filename, "w") as f:
+ write_error_stats(f, f"{test_set_name}", results)
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+
+def get_transducer_model(params: AttributeDict):
+ # encoder = Tdnn(
+ # num_features=params.feature_dim,
+ # output_dim=params.hidden_dim,
+ # )
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ output_dim=params.hidden_dim,
+ )
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ embedding_dim=params.embedding_dim,
+ blank_id=params.blank_id,
+ num_layers=params.num_decoder_layers,
+ hidden_dim=params.hidden_dim,
+ embedding_dropout=0.4,
+ rnn_dropout=0.4,
+ )
+ joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
+ transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
+ return transducer
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ SluDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+ params["env_info"] = get_env_info()
+
+ setup_logger(f"{params.exp_dir}/log/log-decode")
+ logging.info("Decoding started")
+ logging.info(params)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ model = get_transducer_model(params)
+
+ if params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames))
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ slu = SluDataModule(args)
+ test_dl = slu.test_dataloaders()
+ results = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ )
+
+ test_set_name = str(args.feature_dir).split("/")[-2]
+ save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results)
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/fluent_speech_commands/SLU/transducer/decoder.py b/egs/fluent_speech_commands/SLU/transducer/decoder.py
new file mode 120000
index 000000000..e99310f91
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/decoder.py
@@ -0,0 +1 @@
+../../../yesno/ASR/transducer/decoder.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py b/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py
new file mode 120000
index 000000000..653c5b09a
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/joiner.py b/egs/fluent_speech_commands/SLU/transducer/joiner.py
new file mode 120000
index 000000000..75fa64868
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer/joiner.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/model.py b/egs/fluent_speech_commands/SLU/transducer/model.py
new file mode 120000
index 000000000..10f6ddad1
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer/model.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py
new file mode 100755
index 000000000..fa715abdd
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py
@@ -0,0 +1,289 @@
+# Copyright 2021 Piotr Żelasko
+# 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import List
+
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
+from lhotse.dataset import (
+ CutConcatenate,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
+from icefall.dataset.datamodule import DataModule
+from icefall.utils import str2bool
+
+
+class SluDataModule(DataModule):
+ """
+ DataModule for k2 SLU experiments on Fluent Speech Commands.
+ It assumes there is always one train dataloader,
+ but there can be multiple test dataloaders.
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+ """
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ super().add_arguments(parser)
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--feature-dir",
+ type=Path,
+ default=Path("data/fbanks"),
+ help="Path to directory with train/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=30.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=False,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=10,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ def train_dataloaders(self) -> DataLoader:
+ logging.info("About to get train cuts")
+ cuts_train = self.train_cuts()
+
+ logging.info("About to create train dataset")
+ transforms = []
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(
+ FbankConfig(sampling_rate=8000, num_mel_bins=23)
+ ),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=True,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=True,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self) -> DataLoader:
+ logging.info("About to get valid cuts")
+ cuts_valid = self.valid_cuts()
+
+ logging.debug("About to create valid dataset")
+ valid = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create valid dataloader")
+ valid_dl = DataLoader(
+ valid,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ persistent_workers=True,
+ )
+ return valid_dl
+
+ def test_dataloaders(self) -> DataLoader:
+ logging.info("About to get test cuts")
+ cuts_test = self.test_cuts()
+
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts_test,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ persistent_workers=True,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ cuts_train = load_manifest_lazy(
+ self.args.feature_dir / "slu_cuts_train.jsonl.gz"
+ )
+ return cuts_train
+
+ @lru_cache()
+ def valid_cuts(self) -> List[CutSet]:
+ logging.info("About to get valid cuts")
+ cuts_valid = load_manifest_lazy(
+ self.args.feature_dir / "slu_cuts_valid.jsonl.gz"
+ )
+ return cuts_valid
+
+ @lru_cache()
+ def test_cuts(self) -> List[CutSet]:
+ logging.info("About to get test cuts")
+ cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz")
+ return cuts_test
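+
+
+# A minimal usage sketch (assumptions: the manifests under --feature-dir already
+# exist and no extra arguments are required by the base DataModule):
+#
+#     parser = argparse.ArgumentParser()
+#     SluDataModule.add_arguments(parser)
+#     args = parser.parse_args(["--feature-dir", "data/fbanks"])
+#     train_dl = SluDataModule(args).train_dataloaders()
+#     batch = next(iter(train_dl))
+#     print(batch["inputs"].shape)  # (N, T, 23) fbank features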
diff --git a/egs/fluent_speech_commands/SLU/transducer/subsampling.py b/egs/fluent_speech_commands/SLU/transducer/subsampling.py
new file mode 120000
index 000000000..fd7ca8b30
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/subsampling.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/test_conformer.py b/egs/fluent_speech_commands/SLU/transducer/test_conformer.py
new file mode 120000
index 000000000..3060dd70c
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/test_conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer/test_conformer.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/test_decoder.py b/egs/fluent_speech_commands/SLU/transducer/test_decoder.py
new file mode 120000
index 000000000..d1bc718ce
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/test_decoder.py
@@ -0,0 +1 @@
+../../../yesno/ASR/transducer/test_decoder.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/test_joiner.py b/egs/fluent_speech_commands/SLU/transducer/test_joiner.py
new file mode 120000
index 000000000..248222a8a
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/test_joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer/test_joiner.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/test_transducer.py b/egs/fluent_speech_commands/SLU/transducer/test_transducer.py
new file mode 120000
index 000000000..df104bad7
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/test_transducer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer/test_transducer.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/train.py b/egs/fluent_speech_commands/SLU/transducer/train.py
new file mode 100755
index 000000000..a59c0b754
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/train.py
@@ -0,0 +1,625 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import List, Optional, Tuple
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.optim as optim
+from lhotse.utils import fix_random_seed
+from slu_datamodule import SluDataModule
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from transducer.conformer import Conformer
+
+# from torch.utils.tensorboard import SummaryWriter
+from transducer.decoder import Decoder
+from transducer.joiner import Joiner
+from transducer.model import Transducer
+
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+
+def get_word2id(params):
+ word2id = {}
+
+ # 0 is blank
+ id = 1
+ with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
+ for line in lexicon_file:
+ if len(line.strip()) > 0:
+ word2id[line.split()[0]] = id
+ id += 1
+
+ return word2id
+
+
+def get_labels(texts: List[str], word2id) -> k2.RaggedTensor:
+ """
+ Args:
+ texts:
+ A list of transcripts.
+ Returns:
+ Return a ragged tensor containing the corresponding word ID.
+ """
+ # blank is 0
+ word_ids = []
+ for t in texts:
+ words = t.split()
+ ids = [word2id[w] for w in words]
+ word_ids.append(ids)
+
+ return k2.RaggedTensor(word_ids)
+
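+# A hypothetical illustration: with word2id = {"change_language": 1, "none": 2},
+# get_labels(["change_language none", "none"], word2id) returns a k2.RaggedTensor
+# equivalent to [[1, 2], [2]] -- one row of word IDs per utterance; the blank
+# ID 0 never appears in the labels.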
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=7,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ tdnn/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="transducer/exp",
+ help="Directory to save results",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training-related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - lr: It specifies the initial learning rate
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - weight_decay: The weight_decay for the optimizer.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
+ and continue training from that checkpoint.
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if `batch_idx % log_interval` is 0
+
+ - valid_interval: Run validation if `batch_idx % valid_interval` is 0
+
+ - reset_interval: Reset statistics if `batch_idx % reset_interval` is 0
+
+
+ """
+ params = AttributeDict(
+ {
+ "lr": 1e-4,
+ "feature_dim": 23,
+ "weight_decay": 1e-6,
+ "start_epoch": 0,
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 100,
+ "reset_interval": 20,
+ "valid_interval": 3000,
+ "exp_dir": Path("transducer/exp"),
+ "lang_dir": Path("data/lm/frames"),
+ # encoder/decoder params
+ "vocab_size": 3, # blank, yes, no
+ "blank_id": 0,
+ "embedding_dim": 32,
+ "hidden_dim": 16,
+ "num_decoder_layers": 4,
+ }
+ )
+
+ vocab_size = 1
+ with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
+ for line in lexicon_file:
+ if len(line.strip()) > 0:
+ vocab_size += 1
+ params.vocab_size = vocab_size
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> Optional[dict]:
+ """Load checkpoint from file.
+
+ If params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+ Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+ it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The learning rate scheduler we are using.
+ Returns:
+ Return the loaded checkpoint's contents as a dict if a checkpoint was
+ loaded; otherwise return None.
+ """
+ if params.start_epoch <= 0:
+ return
+
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler._LRScheduler,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute RNN-T loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Transducer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ feature_lens = batch["supervisions"]["num_frames"].to(device)
+
+ texts = [
+ " ".join(a.supervisions[0].custom["frames"])
+ for a in batch["supervisions"]["cut"]
+ ]
+ texts = [
+ " " + a.replace("change language", "change_language") + " "
+ for a in texts
+ ]
+ labels = get_labels(texts, word2ids).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ loss = model(x=feature, x_lens=feature_lens, y=labels)
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ info["frames"] = feature.size(0)
+ info["loss"] = loss.detach().cpu().item()
+
+ return loss, info
+
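+# Note: MetricsTracker accumulates counters across batches, so the average loss
+# reported elsewhere is tot_loss["loss"] / tot_loss["frames"]. Here "frames"
+# actually counts utterances, since compute_loss() sets it to feature.size(0),
+# i.e. the batch size.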
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ valid_dl: torch.utils.data.DataLoader,
+ word2ids,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process. The validation loss
+ is saved in `params.valid_loss`.
+ """
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=False,
+ word2ids=word2ids,
+ )
+ assert loss.requires_grad is False
+
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ word2ids,
+ tb_writer: None,
+ world_size: int = 1,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ loss, loss_info = compute_loss(
+ params=params, model=model, batch=batch, is_training=True, word2ids=word2ids
+ )
+ # summary stats.
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ optimizer.zero_grad()
+ loss.backward()
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
+ optimizer.step()
+
+ if batch_idx % params.log_interval == 0:
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}"
+ )
+ if batch_idx % params.log_interval == 0:
+
+ if tb_writer is not None:
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ word2ids=word2ids,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer,
+ "train/valid_",
+ params.batch_idx_train,
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def get_transducer_model(params: AttributeDict):
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ output_dim=params.hidden_dim,
+ )
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ embedding_dim=params.embedding_dim,
+ blank_id=params.blank_id,
+ num_layers=params.num_decoder_layers,
+ hidden_dim=params.hidden_dim,
+ embedding_dropout=0.4,
+ rnn_dropout=0.4,
+ )
+ joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
+ transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
+
+ return transducer
+
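+# Dimension contract implied by get_params(): the Conformer encoder and the
+# Decoder both produce hidden_dim-sized (16) outputs, which must equal the
+# Joiner's input_dim; the Joiner then maps to vocab_size logits
+# (blank plus one output per lexicon entry).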
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+
+ params.update(vars(args))
+ params["env_info"] = get_env_info()
+
+ word2ids = get_word2id(params)
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+ logging.info(params)
+
+ # if args.tensorboard and rank == 0:
+ # tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ # else:
+ # tb_writer = None
+ tb_writer = None
+
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ else:
+ device = torch.device("cpu")
+ logging.info(f"device: {device}")
+
+ model = get_transducer_model(params)
+
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ model = DDP(model, device_ids=[rank])
+
+ model.device = device
+
+ optimizer = optim.Adam(
+ model.parameters(),
+ lr=params.lr,
+ weight_decay=params.weight_decay,
+ )
+
+ if checkpoints:
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ slu = SluDataModule(args)
+ train_dl = slu.train_dataloaders()
+
+ # We use the test set as the validation set.
+ valid_dl = slu.test_dataloaders()
+
+ for epoch in range(params.start_epoch, params.num_epochs):
+ fix_random_seed(params.seed + epoch)
+ train_dl.sampler.set_epoch(epoch)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ word2ids=word2ids,
+ )
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ scheduler=None,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ SluDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/fluent_speech_commands/SLU/transducer/transformer.py b/egs/fluent_speech_commands/SLU/transducer/transformer.py
new file mode 120000
index 000000000..214afed39
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/transformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/transformer.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
index a93e224d5..569978424 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
@@ -261,6 +261,8 @@ class GigaSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py
similarity index 87%
rename from egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
rename to egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py
index 07beeb1f0..9e0df0989 100755
--- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py
@@ -30,15 +30,15 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
-def compute_fbank_gigaspeech_dev_test():
+def compute_fbank_gigaspeech():
in_out_dir = Path("data/fbank")
# number of workers in dataloader
num_workers = 20
# number of seconds in a batch
- batch_duration = 600
+ batch_duration = 1000
- subsets = ("DEV", "TEST")
+ subsets = ("L", "M", "S", "XS", "DEV", "TEST")
device = torch.device("cpu")
if torch.cuda.is_available():
@@ -48,12 +48,12 @@ def compute_fbank_gigaspeech_dev_test():
logging.info(f"device: {device}")
for partition in subsets:
- cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz"
+ cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
- raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz"
+ raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@@ -62,7 +62,7 @@ def compute_fbank_gigaspeech_dev_test():
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
- storage_path=f"{in_out_dir}/feats_{partition}",
+ storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
num_workers=num_workers,
batch_duration=batch_duration,
overwrite=True,
@@ -80,7 +80,7 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
- compute_fbank_gigaspeech_dev_test()
+ compute_fbank_gigaspeech()
if __name__ == "__main__":
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
index 1c71be0f9..51cd59078 100755
--- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
@@ -76,7 +76,7 @@ def get_parser():
def compute_fbank_gigaspeech_splits(args):
num_splits = args.num_splits
- output_dir = "data/fbank/XL_split"
+ output_dir = f"data/fbank/XL_split"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
@@ -96,15 +96,15 @@ def compute_fbank_gigaspeech_splits(args):
logging.info(f"device: {device}")
for i in range(start, stop):
- idx = f"{i + 1}".zfill(num_digits)
+ idx = f"{i}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
- cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
+ cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
- raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
+ raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@@ -113,7 +113,7 @@ def compute_fbank_gigaspeech_splits(args):
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
- storage_path=f"{output_dir}/feats_XL_{idx}",
+ storage_path=f"{output_dir}/gigaspeech_feats_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
overwrite=True,
diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
index 31abe7fff..a31685211 100755
--- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
+++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
@@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import argparse
import logging
import re
from pathlib import Path
@@ -23,10 +24,24 @@ from pathlib import Path
from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached
+from icefall.utils import str2bool
+
# Similar text filtering and normalization procedure as in:
# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--perturb-speed",
+ type=str2bool,
+ default=False,
+ help="Whether to use speed perturbation.",
+ )
+
+ return parser.parse_args()
+
+
def normalize_text(
utt: str,
punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
@@ -42,7 +57,7 @@ def has_no_oov(
return oov_pattern.search(sup.text) is None
-def preprocess_giga_speech():
+def preprocess_giga_speech(args):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir.mkdir(exist_ok=True)
@@ -51,6 +66,10 @@ def preprocess_giga_speech():
"DEV",
"TEST",
"XL",
+ "L",
+ "M",
+ "S",
+ "XS",
)
logging.info("Loading manifest (may take 4 minutes)")
@@ -71,7 +90,7 @@ def preprocess_giga_speech():
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
- raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
+ raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
@@ -94,11 +113,14 @@ def preprocess_giga_speech():
# Run data augmentation that needs to be done in the
# time domain.
if partition not in ["DEV", "TEST"]:
- logging.info(
- f"Speed perturb for {partition} with factors 0.9 and 1.1 "
- "(Perturbing may take 8 minutes and saving may take 20 minutes)"
- )
- cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ if args.perturb_speed:
+ logging.info(
+ f"Speed perturb for {partition} with factors 0.9 and 1.1 "
+ "(Perturbing may take 8 minutes and saving may take 20 minutes)"
+ )
+ cut_set = (
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ )
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
@@ -107,7 +129,8 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
- preprocess_giga_speech()
+ args = get_args()
+ preprocess_giga_speech(args)
if __name__ == "__main__":
diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh
index a23b708d7..5e54b669a 100755
--- a/egs/gigaspeech/ASR/prepare.sh
+++ b/egs/gigaspeech/ASR/prepare.sh
@@ -99,7 +99,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
exit 1;
fi
# Download XL, DEV and TEST sets by default.
- lhotse download gigaspeech --subset auto --host tsinghua \
+ lhotse download gigaspeech --subset XL \
+ --subset L \
+ --subset M \
+ --subset S \
+ --subset XS \
+ --subset DEV \
+ --subset TEST \
+ --host tsinghua \
$dl_dir/password $dl_dir/GigaSpeech
fi
@@ -118,7 +125,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# We assume that you have downloaded the GigaSpeech corpus
# to $dl_dir/GigaSpeech
mkdir -p data/manifests
- lhotse prepare gigaspeech --subset auto -j $nj \
+ lhotse prepare gigaspeech --subset XL \
+ --subset L \
+ --subset M \
+ --subset S \
+ --subset XS \
+ --subset DEV \
+ --subset TEST \
+ -j $nj \
$dl_dir/GigaSpeech data/manifests
fi
@@ -139,8 +153,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
- log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
- python3 ./local/compute_fbank_gigaspeech_dev_test.py
+ log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech."
+ python3 ./local/compute_fbank_gigaspeech.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
@@ -176,18 +190,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
- log "Stage 9: Prepare phone based lang"
+ log "Stage 9: Prepare transcript_words.txt and words.txt"
lang_dir=data/lang_phone
mkdir -p $lang_dir
-
- (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
- cat - $dl_dir/lm/lexicon.txt |
- sort | uniq > $lang_dir/lexicon.txt
-
- if [ ! -f $lang_dir/L_disambig.pt ]; then
- ./local/prepare_lang.py --lang-dir $lang_dir
- fi
-
if [ ! -f $lang_dir/transcript_words.txt ]; then
gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \
| jq '.text' \
@@ -238,7 +243,21 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
- log "Stage 10: Prepare BPE based lang"
+ log "Stage 10: Prepare phone based lang"
+ lang_dir=data/lang_phone
+ mkdir -p $lang_dir
+
+ (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+ cat - $dl_dir/lm/lexicon.txt |
+ sort | uniq > $lang_dir/lexicon.txt
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang.py --lang-dir $lang_dir
+ fi
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+ log "Stage 11: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
@@ -260,8 +279,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
done
fi
-if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
- log "Stage 11: Prepare bigram P"
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+ log "Stage 12: Prepare bigram P"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
@@ -291,8 +310,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
done
fi
-if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
- log "Stage 12: Prepare G"
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+ log "Stage 13: Prepare G"
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
@@ -317,8 +336,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
fi
fi
-if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
- log "Stage 13: Compile HLG"
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+ log "Stage 14: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index b5b27ce95..40339365c 100644
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -294,6 +294,8 @@ class GigaSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
index 72f74c968..ef430302d 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -76,6 +76,7 @@ from beam_search import (
)
from gigaspeech_scoring import asr_text_post_processing
from train import get_params, get_transducer_model
+
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
index b6190e8a6..4a44f7bcb 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
@@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@@ -47,12 +47,13 @@ import argparse
import logging
from pathlib import Path
-import sentencepiece as spm
+import k2
import torch
+from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -98,10 +99,10 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- default="data/lang_bpe_500/bpe.model",
- help="Path to the BPE model",
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt.",
)
parser.add_argument(
@@ -135,12 +136,14 @@ def main():
logging.info(f"device: {device}")
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
+ # Load tokens.txt here
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ # Load id of the <blk> token and the vocab size
+ # <blk> is defined in local/train_bpe_model.py
- params.blank_id = sp.piece_to_id("<blk>")
- params.vocab_size = sp.get_piece_size()
+ params.blank_id = token_table["<blk>"]
+ params.unk_id = token_table["<unk>"]
+ params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@@ -183,6 +186,7 @@ def main():
model.eval()
if params.jit:
+ convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py
index 6adfdbfbb..0501461cd 100644
--- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py
@@ -105,7 +105,7 @@ class GigaSpeechAsrDataModule:
group.add_argument(
"--num-buckets",
type=int,
- default=30,
+ default=100,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
@@ -311,6 +311,8 @@ class GigaSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
@@ -366,6 +368,8 @@ class GigaSpeechAsrDataModule:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
+ num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
shuffle=False,
)
logging.info("About to create dev dataloader")
@@ -415,6 +419,7 @@ class GigaSpeechAsrDataModule:
logging.info(
f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
)
+
cuts_train = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)
diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py
index aa51036d5..651f20cb6 100755
--- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py
+++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py
@@ -88,7 +88,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py
index 7cada8c9d..cb3fd0dc7 100755
--- a/egs/gigaspeech/ASR/zipformer/streaming_decode.py
+++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py
@@ -51,7 +51,7 @@ from streaming_beam_search import (
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py
index d93cc221c..f0ad98147 100755
--- a/egs/gigaspeech/ASR/zipformer/train.py
+++ b/egs/gigaspeech/ASR/zipformer/train.py
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -416,6 +417,17 @@ def get_parser():
help="Accumulate stats on activations, print them and exit.",
)
+ parser.add_argument(
+ "--scan-for-oom-batches",
+ type=str2bool,
+ default=False,
+ help="""
+ Whether to scan for OOM batches before training. This is helpful for
+ finding a suitable max_duration; you only need to run it once.
+ Caution: it is a little time consuming.
+ """,
+ )
+
parser.add_argument(
"--inf-check",
type=str2bool,
@@ -1020,9 +1032,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
@@ -1171,9 +1181,16 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
+ def remove_short_utt(c: Cut):
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ return T > 0
+
gigaspeech = GigaSpeechAsrDataModule(args)
train_cuts = gigaspeech.train_cuts()
+ train_cuts = train_cuts.filter(remove_short_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@@ -1187,9 +1204,10 @@ def run(rank, world_size, args):
)
valid_cuts = gigaspeech.dev_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_utt)
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
- if not params.print_diagnostics:
+ if not params.print_diagnostics and params.scan_for_oom_batches:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
diff --git a/egs/gigaspeech/KWS/RESULTS.md b/egs/gigaspeech/KWS/RESULTS.md
new file mode 100644
index 000000000..992240e14
--- /dev/null
+++ b/egs/gigaspeech/KWS/RESULTS.md
@@ -0,0 +1,49 @@
+# Results
+
+## zipformer transducer model
+
+This is a tiny general ASR model with around 3.3M parameters. See this PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and other details.
+
+The modeling units are 500 BPE tokens trained on GigaSpeech transcripts.
+
+The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is the test set of GigaSpeech (about 40 hours of audio).
+
+We put the whole pipeline (training, decoding and finetuning commands) in `run.sh`.
+
+The models have been uploaded to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-gigaspeech-20240219.tar.gz).
+
+Here are the results on a small test set with 20 commands. We list the results for every command; for
+each metric there are two columns, one for the original model trained on the GigaSpeech XL subset and the other
+for the model finetuned on the commands dataset.
+
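+The false alarm columns are the number of false positives divided by the duration of the negative set; for example, 1 false positive over the 40 hour negative set gives 1 / 40 = 0.025 false alarms per hour.
+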
+Commands | FN in positive set |FN in positive set | Recall | Recall | FP in negative set | FP in negative set| False alarm (time / hour) 40 hours | False alarm (time / hour) 40 hours |
+-- | -- | -- | -- | -- | -- | -- | -- | --
+ | original | finetune | original | finetune | original | finetune | original | finetune
+All | 43/307 | 4/307 | 86% | 98.7% | 1 | 24 | 0.025 | 0.6
+Lights on | 6/17 | 0/17 | 64.7% | 100% | 1 | 9 | 0.025 | 0.225
+Heat up | 5/14 | 1/14 | 64.3% | 92.9% | 0 | 1 | 0 | 0.025
+Volume down | 4/18 | 0/18 | 77.8% | 100% | 0 | 2 | 0 | 0.05
+Volume max | 4/17 | 0/17 | 76.5% | 100% | 0 | 0 | 0 | 0
+Volume mute | 4/16 | 0/16 | 75.0% | 100% | 0 | 0 | 0 | 0
+Too quiet | 3/17 | 0/17 | 82.4% | 100% | 0 | 4 | 0 | 0.1
+Lights off | 3/17 | 0/17 | 82.4% | 100% | 0 | 2 | 0 | 0.05
+Play music | 2/14 | 0/14 | 85.7% | 100% | 0 | 0 | 0 | 0
+Bring newspaper | 2/13 | 1/13 | 84.6% | 92.3% | 0 | 0 | 0 | 0
+Heat down | 2/16 | 2/16 | 87.5% | 87.5% | 0 | 1 | 0 | 0.025
+Volume up | 2/18 | 0/18 | 88.9% | 100% | 0 | 1 | 0 | 0.025
+Too loud | 1/13 | 0/13 | 92.3% | 100% | 0 | 0 | 0 | 0
+Resume music | 1/14 | 0/14 | 92.9% | 100% | 0 | 0 | 0 | 0
+Bring shoes | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
+Switch language | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
+Pause music | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
+Bring socks | 1/12 | 0/12 | 91.7% | 100% | 0 | 0 | 0 | 0
+Stop music | 0/15 | 0/15 | 100% | 100% | 0 | 0 | 0 | 0
+Turn it up | 0/15 | 0/15 | 100% | 100% | 0 | 3 | 0 | 0.075
+Turn it down | 0/16 | 0/16 | 100% | 100% | 0 | 1 | 0 | 0.025
+
+This is the result on the large test set. It has more than 200 commands, too many to list individually, so we only report the overall result here.
+
+Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours
+-- | -- | -- | -- | -- | -- | -- | -- | --
+ | original | finetune | original | finetune | original | finetune | original | finetune
+All | 622/3994 | 79/3994 | 83.6% | 97.9% | 18/19930 | 52/19930 | 0.45 | 1.3
diff --git a/egs/gigaspeech/KWS/prepare.sh b/egs/gigaspeech/KWS/prepare.sh
new file mode 100755
index 000000000..0b098190d
--- /dev/null
+++ b/egs/gigaspeech/KWS/prepare.sh
@@ -0,0 +1,85 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+nj=15
+stage=0
+stop_stage=100
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Prepare gigaspeech dataset."
+ mkdir -p data/fbank
+ if [ ! -e data/fbank/.gigaspeech.done ]; then
+ pushd ../ASR
+ ./prepare.sh --stage 0 --stop-stage 9
+ ./prepare.sh --stage 11 --stop-stage 11
+ popd
+ pushd data/fbank
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/XL_split) .
+ ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
+ popd
+ pushd data
+ ln -svf $(realpath ../ASR/data/lang_bpe_500) .
+ popd
+ touch data/fbank/.gigaspeech.done
+ else
+ log "Gigaspeech dataset already exists, skipping."
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare open commands dataset."
+ mkdir -p data/fbank
+ if [ ! -e data/fbank/.fluent_speech_commands.done ]; then
+ pushd data
+ git clone https://github.com/pkufool/open-commands.git
+ ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
+ ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
+ pushd open-commands
+ ./script/prepare.sh --stage 2 --stop-stage 2
+ ./script/prepare.sh --stage 6 --stop-stage 6
+ popd
+ popd
+ pushd data/fbank
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) .
+ popd
+ touch data/fbank/.fluent_speech_commands.done
+ else
+ log "Fluent speech commands dataset already exists, skipping."
+ fi
+fi
diff --git a/egs/gigaspeech/KWS/run.sh b/egs/gigaspeech/KWS/run.sh
new file mode 100755
index 000000000..bd562ce1c
--- /dev/null
+++ b/egs/gigaspeech/KWS/run.sh
@@ -0,0 +1,201 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export PYTHONPATH=../../../:$PYTHONPATH
+
+stage=0
+stop_stage=100
+
+. shared/parse_options.sh || exit 1
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Train a model."
+ if [ ! -e data/fbank/.gigaspeech.done ]; then
+ log "You need to run the prepare.sh first."
+ exit -1
+ fi
+
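+ # Note: this trains a small causal (streaming) zipformer; the reduced
+ # encoder/decoder dimensions keep the model lightweight for keyword spotting.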
+ python ./zipformer/train.py \
+ --world-size 4 \
+ --exp-dir zipformer/exp \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --num-epochs 12 \
+ --lr-epochs 1.5 \
+ --use-fp16 1 \
+ --start-epoch 1 \
+ --subset XL \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --causal 1 \
+ --max-duration 1000
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Decode the model."
+
+ export CUDA_VISIBLE_DEVICES="0"
+ for t in small large; do
+ python ./zipformer/decode.py \
+ --epoch 12 \
+ --avg 2 \
+ --exp-dir ./zipformer/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 64 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --test-set $t \
+ --keywords-score 1.0 \
+ --keywords-threshold 0.35 \
+ --keywords-file ./data/commands_${t}.txt \
+ --max-duration 3000
+ done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Export the model."
+
+ python ./zipformer/export.py \
+ --epoch 12 \
+ --avg 2 \
+ --exp-dir ./zipformer/exp \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 64 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128
+
+ python ./zipformer/export_onnx_streaming.py \
+ --exp-dir zipformer/exp \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 12 \
+ --avg 2 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --causal 1
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 2: Finetune the model"
+
+ # The following learning-rate schedule should work well.
+ # You may also tune these parameters to adjust the schedule.
+ base_lr=0.0005
+ lr_epochs=100
+ lr_batches=100000
+
+ # We recommend starting from an averaged model, e.g. the pretrained.pt
+ # exported in stage 2.
+ finetune_ckpt=zipformer/exp/pretrained.pt
+
+ ./zipformer/finetune.py \
+ --world-size 4 \
+ --num-epochs 10 \
+ --start-epoch 1 \
+ --exp-dir zipformer/exp_finetune \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --use-fp16 1 \
+ --use-mux 1 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --causal 1 \
+ --base-lr $base_lr \
+ --lr-epochs $lr_epochs \
+ --lr-batches $lr_batches \
+ --finetune-ckpt $finetune_ckpt \
+ --max-duration 1500
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 1: Decode the finetuned model."
+ export CUDA_VISIBLE_DEVICES="0"
+ for t in small large; do
+ python ./zipformer/decode.py \
+ --epoch 10 \
+ --avg 2 \
+ --exp-dir ./zipformer/exp_finetune \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 64 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --test-set $t \
+ --keywords-score 1.0 \
+ --keywords-threshold 0.35 \
+ --keywords-file ./data/commands_${t}.txt \
+ --max-duration 3000
+ done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 2: Export the finetuned model."
+
+ python ./zipformer/export.py \
+ --epoch 10 \
+ --avg 2 \
+ --exp-dir ./zipformer/exp_finetune \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 64 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128
+
+ python ./zipformer/export_onnx_streaming.py \
+ --exp-dir zipformer/exp_finetune \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 10 \
+ --avg 2 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --causal 1
+fi
diff --git a/egs/gigaspeech/KWS/shared b/egs/gigaspeech/KWS/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/gigaspeech/KWS/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py
new file mode 100644
index 000000000..ccc602404
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py
@@ -0,0 +1,477 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import glob
+import inspect
+import logging
+import re
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import lhotse
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class GigaSpeechAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
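+
+    A minimal usage sketch (this mirrors how the decoding scripts in this
+    directory construct it; see ``add_arguments`` below for the options):
+
+        parser = argparse.ArgumentParser()
+        GigaSpeechAsrDataModule.add_arguments(parser)
+        args = parser.parse_args()
+        gigaspeech = GigaSpeechAsrDataModule(args)
+        train_dl = gigaspeech.train_dataloaders(gigaspeech.train_cuts())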
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ # GigaSpeech specific arguments
+ group.add_argument(
+ "--subset",
+ type=str,
+ default="XL",
+ help="Select the GigaSpeech subset (XS|S|M|L|XL)",
+ )
+ group.add_argument(
+ "--small-dev",
+ type=str2bool,
+ default=False,
+ help="Should we use only 1000 utterances for dev (speeds up training)",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info(f"About to get train {self.args.subset} cuts")
+ if self.args.subset == "XL":
+ filenames = glob.glob(
+ f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz"
+ )
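+            # Sort the split manifests by their numeric split index so that
+            # they are combined in a deterministic order.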
+ pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
+ idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
+ idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
+ sorted_filenames = [f[1] for f in idx_filenames]
+ logging.info(
+ f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
+ )
+
+ cuts_train = lhotse.combine(
+ lhotse.load_manifest_lazy(p) for p in sorted_filenames
+ )
+ else:
+ path = (
+ self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
+ )
+ cuts_train = CutSet.from_jsonl_lazy(path)
+ return cuts_train
+
+ @lru_cache()
+ def dev_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ cuts_valid = load_manifest_lazy(
+ self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
+ )
+ if self.args.small_dev:
+ return cuts_valid.subset(first=1000)
+ else:
+ return cuts_valid
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
+ )
+
+ @lru_cache()
+ def fsc_train_cuts(self) -> CutSet:
+ logging.info("About to get fluent speech commands train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def fsc_valid_cuts(self) -> CutSet:
+ logging.info("About to get fluent speech commands valid cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz"
+ )
+
+ @lru_cache()
+ def fsc_test_small_cuts(self) -> CutSet:
+ logging.info("About to get fluent speech commands small test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz"
+ )
+
+ @lru_cache()
+ def fsc_test_large_cuts(self) -> CutSet:
+ logging.info("About to get fluent speech commands large test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz"
+ )
diff --git a/egs/gigaspeech/KWS/zipformer/beam_search.py b/egs/gigaspeech/KWS/zipformer/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/decode-asr.py b/egs/gigaspeech/KWS/zipformer/decode-asr.py
new file mode 100755
index 000000000..149b8bed0
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/decode-asr.py
@@ -0,0 +1,1066 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode-asr.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode-asr.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode-asr.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode-asr.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode-asr.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode-asr.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode-asr.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from gigaspeech_scoring import asr_text_post_processing
+from train import add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+ add_model_arguments(parser)
+
+ return parser
+
+
+def post_processing(
+ results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+ new_results = []
+ for key, ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((key, new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = post_processing(results)
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append(line.strip())
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(sp.encode(contexts))
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ gigaspeech = GigaSpeechAsrDataModule(args)
+
+ test_cuts = gigaspeech.test_cuts()
+ test_dl = gigaspeech.test_dataloaders(test_cuts)
+
+ test_fsc_cuts = gigaspeech.fsc_test_large_cuts()
+ test_fsc_dl = gigaspeech.test_dataloaders(test_fsc_cuts)
+
+ test_sets = ["test", "fsc_test"]
+ test_dls = [test_dl, test_fsc_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py
new file mode 100755
index 000000000..0df2ec356
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/decode.py
@@ -0,0 +1,687 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --keywords-file keywords.txt \
+ --beam-size 4
+"""
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, List, Optional, Set, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from beam_search import keywords_search
+from lhotse.cut import Cut
+from train import add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+@dataclass
+class KwMetric:
+ TP: int = 0 # True positive
+ FN: int = 0 # False negative
+ FP: int = 0 # False positive
+ TN: int = 0 # True negative
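+    # The results tables are derived from these counters, e.g.
+    #   recall = TP / (TP + FN)
+    #   false alarms per hour = FP / hours of negative audio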
+ FN_list: List[str] = field(default_factory=list)
+ FP_list: List[str] = field(default_factory=list)
+ TP_list: List[str] = field(default_factory=list)
+
+ def __str__(self) -> str:
+ return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})"
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--keywords-file",
+ type=str,
+ help="File contains keywords.",
+ )
+
+ parser.add_argument(
+ "--test-set",
+ type=str,
+ default="small",
+ help="small or large",
+ )
+
+ parser.add_argument(
+ "--keywords-score",
+ type=float,
+ default=1.5,
+ help="""
+        The default boosting score (token level) for keywords. It will boost the
+ paths that match keywords to make them survive beam search.
+ """,
+ )
+
+ parser.add_argument(
+ "--keywords-threshold",
+ type=float,
+ default=0.35,
+ help="The default threshold (probability) to trigger the keyword.",
+ )
+
+ parser.add_argument(
+ "--num-tailing-blanks",
+ type=int,
+ default=1,
+ help="The number of tailing blanks should have after hitting one keyword.",
+ )
+
+ parser.add_argument(
+ "--blank-penalty",
+ type=float,
+ default=0.0,
+ help="""
+ The penalty applied on blank symbol during decoding.
+ Note: It is a positive value that would be applied to logits like
+ this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+ [batch_size, vocab] and blank id is 0).
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ keywords_graph: Optional[ContextGraph] = None,
+) -> List[List[Tuple[str, Tuple[int, int]]]]:
+ """Decode one batch and return the result in a list.
+
+    The length of the list equals the batch size; the i-th element contains the
+    triggered keywords for the i-th utterance in the given batch. The triggered
+    keywords are also a list; each element is a tuple of the triggered keyword
+    and its start and end timestamps.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ keywords_graph:
+ The graph containing keywords.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned list.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ ans_dict = keywords_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ keywords_graph=keywords_graph,
+ beam=params.beam,
+ num_tailing_blanks=params.num_tailing_blanks,
+ blank_penalty=params.blank_penalty,
+ )
+
+ hyps = []
+ for ans in ans_dict:
+ hyp = []
+ for hit in ans:
+ hyp.append((hit.phrase, (hit.timestamps[0], hit.timestamps[-1])))
+ hyps.append(hyp)
+
+ return hyps
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ keywords_graph: ContextGraph,
+ keywords: Set[str],
+ test_only_keywords: bool,
+) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+      keywords_graph:
+        The graph containing keywords.
+      keywords:
+        The set of keyword phrases to be detected.
+      test_only_keywords:
+        If True, the reference transcript is expected to match a keyword
+        exactly; otherwise a keyword only needs to appear somewhere in the
+        reference transcript.
+    Returns:
+      Return a tuple of (results, metric). `results` is a list of tuples,
+      each containing the cut id, the reference words and the triggered
+      keywords; `metric` is a dict mapping each keyword (plus the summary
+      key "all") to its KwMetric counters.
+    """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ log_interval = 50
+
+ results = []
+ metric = {"all": KwMetric()}
+ for k in keywords:
+ metric[k] = KwMetric()
+
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ keywords_graph=keywords_graph,
+ batch=batch,
+ )
+
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_text = ref_text.upper()
+ ref_words = ref_text.split()
+ hyp_words = [x[0] for x in hyp_words]
+ # for computing WER
+ this_batch.append((cut_id, ref_words, " ".join(hyp_words).split()))
+ hyp_set = set(hyp_words) # each item is a keyword phrase
+            if len(hyp_words) > 1:
+                logging.warning(
+                    f"Cut {cut_id} triggers more than one keyword: {hyp_words}. "
+                    "Please check the transcript to see if it really contains "
+                    "more than one keyword; if so, consider splitting this audio "
+                    "and keeping only one keyword per audio."
+                )
+            hyp_str = " | ".join(
+                hyp_words
+            )  # The triggered keywords for this utterance.
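+            # Illustrative (hypothetical) case: with keywords {"HEY SNIPS",
+            # "LIGHTS ON"}, ref_text "HEY SNIPS" and hyp_set {"HEY SNIPS"},
+            # the loops below do metric["HEY SNIPS"].TP += 1 and
+            # metric["LIGHTS ON"].TN += 1, and the "all" summary gets a single TP.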
+ TP = False
+ FP = False
+ for x in hyp_set:
+ assert x in keywords, x # can only trigger keywords
+ if (test_only_keywords and x == ref_text) or (
+ not test_only_keywords and x in ref_text
+ ):
+ TP = True
+ metric[x].TP += 1
+ metric[x].TP_list.append(f"({ref_text} -> {x})")
+ if (test_only_keywords and x != ref_text) or (
+ not test_only_keywords and x not in ref_text
+ ):
+ FP = True
+ metric[x].FP += 1
+ metric[x].FP_list.append(f"({ref_text} -> {x})")
+ if TP:
+ metric["all"].TP += 1
+ if FP:
+ metric["all"].FP += 1
+            TN = True  # If all keywords are true negatives, the summary is a true negative.
+ FN = False
+ for x in keywords:
+ if x not in ref_text and x not in hyp_set:
+ metric[x].TN += 1
+ continue
+
+ TN = False
+ if (test_only_keywords and x == ref_text) or (
+ not test_only_keywords and x in ref_text
+ ):
+ fn = True
+ for y in hyp_set:
+ if (test_only_keywords and y == ref_text) or (
+ not test_only_keywords and y in ref_text
+ ):
+ fn = False
+ break
+ if fn:
+ FN = True
+ metric[x].FN += 1
+ metric[x].FN_list.append(f"({ref_text} -> {hyp_str})")
+ if TN:
+ metric["all"].TN += 1
+ if FN:
+ metric["all"].FN += 1
+
+ results.extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results, metric
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results: List[Tuple[str, List[str], List[str]]],
+ metric: KwMetric,
+):
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt"
+
+ with open(metric_filename, "w") as of:
+ width = 10
+ for key, item in sorted(
+ metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True
+ ):
+ acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN)
+ precision = (
+ 0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP)
+ )
+ recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN)
+ fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN)
+ s = f"{key}:\n"
+ s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n"
+ s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n"
+ s += f"\tAccuracy: {acc:.3f}\n"
+ s += f"\tPrecision: {precision:.3f}\n"
+ s += f"\tRecall(PPR): {recall:.3f}\n"
+ s += f"\tFPR: {fpr:.3f}\n"
+ s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
+ if key != "all":
+ s += f"\tTP list: {' # '.join(item.TP_list)}\n"
+ s += f"\tFP list: {' # '.join(item.FP_list)}\n"
+ s += f"\tFN list: {' # '.join(item.FN_list)}\n"
+ of.write(s + "\n")
+ if key == "all":
+ logging.info(s)
+ of.write(f"\n\n{params.keywords_config}")
+
+ logging.info("Wrote metric stats to {}".format(metric_filename))
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "kws"
+
+ params.suffix = params.test_set
+ if params.iter > 0:
+ params.suffix += f"-iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix += f"-epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ params.suffix += f"-score-{params.keywords_score}"
+ params.suffix += f"-threshold-{params.keywords_threshold}"
+ params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
+ if params.blank_penalty != 0:
+ params.suffix += f"-blank-penalty-{params.blank_penalty}"
+ params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ phrases = []
+ token_ids = []
+ keywords_scores = []
+ keywords_thresholds = []
+ keywords_config = []
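+    # Each line of --keywords-file holds one keyword phrase; a token starting
+    # with ":" sets a per-keyword boosting score and a token starting with "#"
+    # sets a per-keyword triggering threshold. Hypothetical example lines:
+    #
+    #   HEY SNIPS :2.0 #0.4
+    #   TURN ON THE LIGHT
+    #
+    # Keywords without explicit values keep score/threshold 0 here and are
+    # assumed to fall back to the graph-level defaults given by
+    # --keywords-score and --keywords-threshold.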
+ with open(params.keywords_file, "r") as f:
+ for line in f.readlines():
+ keywords_config.append(line)
+ score = 0
+ threshold = 0
+ keyword = []
+ words = line.strip().upper().split()
+ for word in words:
+ word = word.strip()
+ if word[0] == ":":
+ score = float(word[1:])
+ continue
+ if word[0] == "#":
+ threshold = float(word[1:])
+ continue
+ keyword.append(word)
+ keyword = " ".join(keyword)
+ phrases.append(keyword)
+ token_ids.append(sp.encode(keyword))
+ keywords_scores.append(score)
+ keywords_thresholds.append(threshold)
+
+ params.keywords_config = "".join(keywords_config)
+
+ keywords_graph = ContextGraph(
+ context_score=params.keywords_score, ac_threshold=params.keywords_threshold
+ )
+ keywords_graph.build(
+ token_ids=token_ids,
+ phrases=phrases,
+ scores=keywords_scores,
+ ac_thresholds=keywords_thresholds,
+ )
+ keywords = set(phrases)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ gigaspeech = GigaSpeechAsrDataModule(args)
+
+ test_cuts = gigaspeech.test_cuts()
+ test_dl = gigaspeech.test_dataloaders(test_cuts)
+
+ if params.test_set == "small":
+ test_fsc_small_cuts = gigaspeech.fsc_test_small_cuts()
+ test_fsc_small_dl = gigaspeech.test_dataloaders(test_fsc_small_cuts)
+ test_sets = ["small-fsc", "test"]
+ test_dls = [test_fsc_small_dl, test_dl]
+ else:
+ assert params.test_set == "large", params.test_set
+ test_fsc_large_cuts = gigaspeech.fsc_test_large_cuts()
+ test_fsc_large_dl = gigaspeech.test_dataloaders(test_fsc_large_cuts)
+ test_sets = ["large-fsc", "test"]
+ test_dls = [test_fsc_large_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results, metric = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ keywords_graph=keywords_graph,
+ keywords=keywords,
+ test_only_keywords="fsc" in test_set,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results=results,
+ metric=metric,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/KWS/zipformer/decoder.py b/egs/gigaspeech/KWS/zipformer/decoder.py
new file mode 120000
index 000000000..5a8018680
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/encoder_interface.py b/egs/gigaspeech/KWS/zipformer/encoder_interface.py
new file mode 120000
index 000000000..653c5b09a
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py b/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py
new file mode 120000
index 000000000..2962eb784
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-streaming.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/export.py b/egs/gigaspeech/KWS/zipformer/export.py
new file mode 120000
index 000000000..dfc1bec08
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py
new file mode 100755
index 000000000..a7ba56127
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/finetune.py
@@ -0,0 +1,642 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For non-streaming model fine-tuning:
+./zipformer/finetune.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+# For streaming model fine-tuning:
+./zipformer/finetune.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --max-duration 1000
+
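+# A hypothetical invocation showing the fine-tuning specific options
+# (paths and module names below are placeholders):
+./zipformer/finetune.py \
+  --world-size 8 \
+  --num-epochs 10 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer/exp_finetune \
+  --finetune-ckpt zipformer/exp/pretrained.pt \
+  --init-modules "encoder" \
+  --use-mux 0 \
+  --max-duration 1000
+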
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut, CutSet
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from train import (
+ add_model_arguments,
+ add_training_arguments,
+ compute_loss,
+ compute_validation_loss,
+ display_and_save_batch,
+ get_adjusted_batch_count,
+ get_model,
+ get_params,
+ load_checkpoint_if_available,
+ save_checkpoint,
+ scan_pessimistic_batches_for_oom,
+ set_batch_count,
+)
+
+from icefall import diagnostics
+from icefall.checkpoint import remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--use-mux",
+ type=str2bool,
+ default=False,
+ help="""
+        Whether to mix the original training data with the fine-tuning data.
+        If true, the original GigaSpeech data (weight 0.9) is muxed with the
+        new data (weight 0.1) during fine-tuning; otherwise only the new data
+        is used.
+ """,
+ )
+
+ parser.add_argument(
+ "--init-modules",
+ type=str,
+ default=None,
+ help="""
+        Modules to be initialized. It matches all parameters starting with
+        a specific key. The keys are comma separated. If None,
+        all modules will be initialized. For example, if you only want to
+        initialize all parameters starting with "encoder", use "encoder";
+        if you want to initialize parameters starting with encoder or decoder,
+        use "encoder,decoder".
+ """,
+ )
+
+ parser.add_argument(
+ "--finetune-ckpt",
+ type=str,
+ default=None,
+ help="Fine-tuning from which checkpoint (a path to a .pt file)",
+ )
+
+ parser.add_argument(
+ "--continue-finetune",
+ type=str2bool,
+ default=False,
+ help="Continue finetuning or finetune from pre-trained model",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ add_training_arguments(parser)
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def load_model_params(
+ ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
+):
+ """Load model params from checkpoint
+
+ Args:
+ ckpt (str): Path to the checkpoint
+ model (nn.Module): model to be loaded
+
+ """
+ logging.info(f"Loading checkpoint from {ckpt}")
+ checkpoint = torch.load(ckpt, map_location="cpu")
+
+ # if module list is empty, load the whole model from ckpt
+ if not init_modules:
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ model.load_state_dict(dst_state_dict, strict=strict)
+ else:
+ model.load_state_dict(checkpoint["model"], strict=strict)
+ else:
+ src_state_dict = checkpoint["model"]
+ dst_state_dict = model.state_dict()
+ for module in init_modules:
+ logging.info(f"Loading parameters starting with prefix {module}")
+ src_keys = [
+ k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ dst_keys = [
+ k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ assert set(src_keys) == set(dst_keys) # two sets should match exactly
+ for key in src_keys:
+ dst_state_dict[key] = src_state_dict.pop(key)
+
+ model.load_state_dict(dst_state_dict, strict=strict)
+
+ return None
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+
+ # if params.continue_finetune:
+ # set_batch_count(model, params.batch_idx_train)
+ # else:
+ # set_batch_count(model, params.batch_idx_train + 100000)
+
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+
+ if params.continue_finetune:
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+ else:
+ modules = params.init_modules.split(",") if params.init_modules else None
+ checkpoints = load_model_params(
+ ckpt=params.finetune_ckpt, model=model, init_modules=modules
+ )
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_utt(c: Cut):
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
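+        # e.g. (illustrative): c.num_frames = 9 gives T = ((9 - 7) // 2 + 1) // 2 = 1
+        # and the cut is kept, while c.num_frames = 7 gives T = 0 and the cut is dropped.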
+ return T > 0
+
+ gigaspeech = GigaSpeechAsrDataModule(args)
+
+ if params.use_mux:
+ train_cuts = CutSet.mux(
+ gigaspeech.train_cuts(),
+ gigaspeech.fsc_train_cuts(),
+ weights=[0.9, 0.1],
+ )
+ else:
+ train_cuts = gigaspeech.fsc_train_cuts()
+
+ train_cuts = train_cuts.filter(remove_short_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = gigaspeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = gigaspeech.fsc_valid_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_utt)
+ valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics and params.scan_for_oom_batches:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ main()
diff --git a/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py b/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py
new file mode 120000
index 000000000..4ee54fff5
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py
@@ -0,0 +1 @@
+../../ASR/zipformer/gigaspeech_scoring.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/joiner.py b/egs/gigaspeech/KWS/zipformer/joiner.py
new file mode 120000
index 000000000..5b8a36332
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/model.py b/egs/gigaspeech/KWS/zipformer/model.py
new file mode 120000
index 000000000..cd7e07d72
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/model.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/optim.py b/egs/gigaspeech/KWS/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/scaling.py b/egs/gigaspeech/KWS/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/subsampling.py b/egs/gigaspeech/KWS/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py
new file mode 100755
index 000000000..a4d670169
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/train.py
@@ -0,0 +1,1366 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For non-streaming model training:
+./zipformer/train.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/train.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
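+    # e.g. (illustrative numbers): with batch_idx_train=1000, max_duration=600,
+    # world_size=4 and ref_duration=600, the adjusted count is
+    # 1000 * (600 * 4) / 600 = 4000.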
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="1,1,1,1,1,1",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="192,192,192,192,192,192",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="128,128,128,128,128,128",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="128,128,128,128,128,128",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=320,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=320,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=True,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ add_training_arguments(parser)
+ add_model_arguments(parser)
+
+ return parser
+
+
+def add_training_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=1,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--scan-for-oom-batches",
+ type=str2bool,
+ default=False,
+ help="""
+        Whether to scan for OOM batches before training. This is helpful for
+        finding a suitable max_duration; you only need to run it once.
+        Caution: it is somewhat time consuming.
+ """,
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=8000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to write statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 500,
+ "reset_interval": 2000,
+ "valid_interval": 20000,
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
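+    # e.g. _to_int_tuple("16,32,64,-1") == (16, 32, 64, -1)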
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
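+    # e.g. (illustrative): T = 107 input frames -> (107 - 7) // 2 = 50 frames
+    # after encoder_embed, and roughly 25 frames after the final factor-2
+    # downsampling.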
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ sp:
+ The BPE model.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # ramp the scale on the simple loss down from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
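The warm-up scaling above is easiest to see in isolation; below is a minimal sketch of the same schedule, where the function name and the example values for warm_step and simple_loss_scale are invented.

    def warmup_scales(batch_idx_train: int, warm_step: int, simple_loss_scale: float):
        # Reproduces the ramp used in compute_loss():
        #   simple loss scale: 1.0 -> simple_loss_scale over warm_step batches
        #   pruned loss scale: 0.1 -> 1.0 over warm_step batches
        if batch_idx_train >= warm_step:
            return simple_loss_scale, 1.0
        frac = batch_idx_train / warm_step
        return 1.0 - frac * (1.0 - simple_loss_scale), 0.1 + 0.9 * frac

    print(warmup_scales(0, 2000, 0.5))     # (1.0, 0.1)
    print(warmup_scales(1000, 2000, 0.5))  # (0.75, 0.55)
    print(warmup_scales(2000, 2000, 0.5))  # (0.5, 1.0)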
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler; we call step_batch() after every batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum, so the loss is summed over the utterances
+ # in the batch and has not been normalized yet.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
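+ # e.g. with --save-every-n 4000 this writes checkpoint-4000.pt,
+ # checkpoint-8000.pt, ... and keeps only the newest --keep-last-k of them.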
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale is small, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have
+ # different behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
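A compact sketch of the fp16 grad-scale watchdog in the loop above; the checks run only every 100 batches when --use-fp16 is on, and the function name and return values are invented for illustration.

    from typing import List

    def grad_scale_actions(cur_grad_scale: float, batch_idx: int) -> List[str]:
        # Mirrors the three checks in train_one_epoch().
        actions = []
        if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
            actions.append("double")  # scaler.update(cur_grad_scale * 2.0)
        if cur_grad_scale < 0.01:
            actions.append("warn")    # grad scale is suspiciously small
        if cur_grad_scale < 1.0e-05:
            actions.append("abort")   # save a bad-model checkpoint and raise
        return actions

    assert grad_scale_actions(16.0, 100) == []
    assert grad_scale_actions(16.0, 400) == ["double"]
    assert grad_scale_actions(1.0e-03, 100) == ["double", "warn"]
    assert grad_scale_actions(1.0e-06, 100) == ["double", "warn", "abort"]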
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoints.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_utt(c: Cut):
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
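+ # e.g. num_frames = 8 -> T = 0 (dropped); num_frames = 9 -> T = 1 (kept),
+ # i.e. cuts shorter than 9 feature frames are filtered out.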
+ return T > 0
+
+ gigaspeech = GigaSpeechAsrDataModule(args)
+
+ train_cuts = gigaspeech.train_cuts()
+ train_cuts = train_cuts.filter(remove_short_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We load the sampler's state dict only when resuming from a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = gigaspeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = gigaspeech.dev_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_utt)
+ valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics and params.scan_for_oom_batches:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/KWS/zipformer/zipformer.py b/egs/gigaspeech/KWS/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
index c1abdbdb5..500df9ea4 100644
--- a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
+++ b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
@@ -256,6 +256,8 @@ class LibriCssAsrDataModule:
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/libricss/SURT/dprnn_zipformer/train.py b/egs/libricss/SURT/dprnn_zipformer/train.py
index 6598f8b5d..90d742e7c 100755
--- a/egs/libricss/SURT/dprnn_zipformer/train.py
+++ b/egs/libricss/SURT/dprnn_zipformer/train.py
@@ -85,6 +85,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1169,9 +1170,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py
index 1c1b0c28c..8c37430ec 100755
--- a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py
+++ b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1056,9 +1057,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
index df761c1b8..4985f3f4c 100644
--- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py
+++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
@@ -232,7 +232,7 @@ class LibriHeavyAsrDataModule:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
- CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@@ -310,6 +310,8 @@ class LibriHeavyAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py
index c97da4a11..8d4d9d067 100644
--- a/egs/libriheavy/ASR/zipformer/train.py
+++ b/egs/libriheavy/ASR/zipformer/train.py
@@ -93,6 +93,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1036,9 +1037,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
index 690003377..552f63905 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
@@ -341,6 +341,8 @@ class LibriHeavyAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
@@ -423,9 +425,11 @@ class LibriHeavyAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
- input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
- if self.args.on_the_fly_feats
- else PrecomputedFeatures(),
+ input_strategy=(
+ OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures()
+ ),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py
index c8b20d021..93f7e1248 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py
@@ -103,6 +103,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1051,9 +1052,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
index 9822b99c1..2a2c206aa 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
@@ -117,6 +117,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -855,9 +856,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
- context_dim=4 * 768
- if params.context_injection
- else -1, # the output dim of text encoder
+ context_dim=(
+ 4 * 768 if params.context_injection else -1
+ ), # the output dim of text encoder
context_injection=params.context_injection,
)
return joiner
@@ -1398,9 +1399,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/librilight/SSL/zipformer/asr_datamodule.py b/egs/librilight/SSL/zipformer/asr_datamodule.py
new file mode 120000
index 000000000..b9313bffc
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/asr_datamodule.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/beam_search.py b/egs/librilight/SSL/zipformer/beam_search.py
new file mode 120000
index 000000000..3b02c21db
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/beam_search.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/dataset.py b/egs/librilight/SSL/zipformer/dataset.py
new file mode 120000
index 000000000..5cd60d3b4
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/dataset.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/dataset.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/decode.py b/egs/librilight/SSL/zipformer/decode.py
new file mode 100644
index 000000000..95643c5e1
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/decode.py
@@ -0,0 +1,1045 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao,
+# Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from finetune import add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase per line.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search".
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7".
+ - value: It contains the decoding result. `len(value)` equals the
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+
+ encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(batch["supervisions"]["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
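The keys returned above encode the decoding configuration; the sketch below rebuilds just the fast_beam_search family of keys using the get_parser() defaults (the helper name is invented).

    from typing import Optional

    def fast_beam_search_key(
        beam: float,
        max_contexts: int,
        max_states: int,
        num_paths: Optional[int] = None,
        nbest_scale: Optional[float] = None,
        ngram_lm_scale: Optional[float] = None,
    ) -> str:
        key = f"beam_{beam}_max_contexts_{max_contexts}_max_states_{max_states}"
        if num_paths is not None:  # the *_nbest variants
            key += f"_num_paths_{num_paths}_nbest_scale_{nbest_scale}"
        if ngram_lm_scale is not None:  # the *_LG variant
            key += f"_ngram_lm_scale_{ngram_lm_scale}"
        return key

    print(fast_beam_search_key(20.0, 8, 64, 200, 0.5, 0.01))
    # -> beam_20.0_max_contexts_8_max_states_64_num_paths_200_nbest_scale_0.5_ngram_lm_scale_0.01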
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if a beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut id, the reference transcript, and the predicted result,
+ the latter two given as lists of words.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["cuts"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
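For concreteness, the returned structure looks roughly like this (cut id and words are invented):

    results = {
        "greedy_search": [
            # (cut_id, reference words, hypothesis words)
            ("1089-134686-0001", ["HELLO", "WORLD"], ["HELLO", "WORD"]),
        ],
    }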
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ dev_clean_cuts = librispeech.dev_clean_cuts()
+ dev_other_cuts = librispeech.dev_other_cuts()
+
+ dev_clean_dl = librispeech.test_dataloaders(
+ dev_clean_cuts,
+ do_normalize=params.do_normalize,
+ )
+ dev_other_dl = librispeech.test_dataloaders(
+ dev_other_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ test_clean_cuts = librispeech.test_clean_cuts()
+ test_other_cuts = librispeech.test_other_cuts()
+
+ test_clean_dl = librispeech.test_dataloaders(
+ test_clean_cuts,
+ do_normalize=params.do_normalize,
+ )
+ test_other_dl = librispeech.test_dataloaders(
+ test_other_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
+ test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
+ # test_sets = ["dev-clean", "dev-other"]
+ # test_dl = [dev_clean_dl, dev_other_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librilight/SSL/zipformer/decoder.py b/egs/librilight/SSL/zipformer/decoder.py
new file mode 120000
index 000000000..96dbfc5cd
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/encoder_interface.py b/egs/librilight/SSL/zipformer/encoder_interface.py
new file mode 120000
index 000000000..30859c51b
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/finetune.py b/egs/librilight/SSL/zipformer/finetune.py
new file mode 100644
index 000000000..50dbd5f2d
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/finetune.py
@@ -0,0 +1,1552 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+#
+# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For HuBERT model finetuning:
+./hubert/finetune.py \
+ --world-size 8 \
+ --num-epochs 200 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir hubert/exp \
+ --full-libri 0 \
+ --max-duration 1000
+
+It supports finetuning with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from hubert_ce import HubertModel
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
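+    # Illustrative example (hypothetical numbers, for clarity only): with
+    # batch_idx_train=1000, accum_grad=2, max_duration=600, world_size=4 and
+    # ref_duration=600, this returns 1000 * 2 * (600 * 4) / 600 = 8000.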
+ return (
+ params.batch_idx_train
+ * params.accum_grad
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ # hubert parameters
+ parser.add_argument(
+ "--label-rate",
+ type=float,
+ default=50,
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=float,
+ default=16000,
+ )
+
+ parser.add_argument(
+ "--extractor-mode",
+ type=str,
+ default="default",
+        help="""mode for the feature extractor; should be one of EXTRACTOR_MODE_CHOICES. default has a single group
+        norm with d groups in the first conv block, whereas layer_norm
+        has layer norms in every block (meant to be used with normalize=True)""",
+ )
+
+ parser.add_argument(
+ "--conv-feature-layers",
+ type=str,
+ default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+ help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
+ )
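+    # Note (an assumption about how hubert_ce.py consumes this string, kept
+    # here only as an illustration): the default above is a Python expression
+    # that expands to seven (dim, kernel_size, stride) conv layers, i.e. one
+    # (512, 10, 5), four (512, 3, 2) and two (512, 2, 2), for an overall
+    # downsampling of 5 * 2**6 = 320 samples (20 ms at 16 kHz), consistent
+    # with the 50 Hz --label-rate above.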
+
+ parser.add_argument(
+ "--conv-bias", type=bool, default=False, help="include bias in conv encoder"
+ )
+
+ parser.add_argument(
+ "--feature-grad-mult",
+ type=float,
+ default=1.0,
+ help="multiply feature extractor var grads by this",
+ )
+
+ # masking
+ parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
+
+ parser.add_argument(
+ "--mask-prob",
+ type=float,
+ default=0.65,
+ help="probability of replacing a token with mask",
+ )
+
+ parser.add_argument(
+ "--mask-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length",
+ )
+
+ parser.add_argument(
+ "--mask-other",
+ type=float,
+ default=0,
+        help="secondary mask argument (used for more complex distributions); see help in compute_mask_indices",
+ )
+
+ parser.add_argument(
+ "--no-mask-overlap",
+ type=bool,
+ default=False,
+        help="if true, disallow masks to overlap",
+ )
+
+ parser.add_argument(
+ "--mask-min-space",
+ type=int,
+ default=1,
+        help="min space between spans (when overlap is not allowed)",
+ )
+
+ # channel masking
+ parser.add_argument(
+ "--mask-channel-length",
+ type=int,
+ default=10,
+ help="length of the mask for features (channels)",
+ )
+
+ parser.add_argument(
+ "--mask-channel-prob",
+ type=float,
+ default=0.0,
+ help="probability of replacing a feature with 0",
+ )
+
+ parser.add_argument(
+ "--mask-channel-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length for channel masking",
+ )
+
+ parser.add_argument(
+ "--mask-channel-other",
+ type=float,
+ default=0,
+        help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
+ )
+
+ parser.add_argument(
+ "--no-mask-channel-overlap",
+ type=bool,
+ default=False,
+        help="if true, disallow channel masks to overlap",
+ )
+
+ parser.add_argument(
+ "--mask-channel-min-space",
+ type=int,
+ default=1,
+        help="min space between spans (when overlap is not allowed)",
+ )
+
+ # loss computation
+ parser.add_argument(
+ "--skip-masked",
+ type=bool,
+ default=False,
+ help="skip computing losses over masked frames",
+ )
+
+ parser.add_argument(
+ "--skip-nomask",
+ type=bool,
+ default=False,
+ help="skip computing losses over unmasked frames",
+ )
+
+ parser.add_argument(
+ "--checkpoint-activations",
+ type=bool,
+ default=False,
+ help="recompute activations and save memory for extra compute",
+ )
+
+ parser.add_argument(
+ "--pred-masked-weight",
+ type=float,
+ default=1,
+ help="weight for masked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--pred-nomask-weight",
+ type=float,
+ default=0,
+        help="weight for unmasked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--loss-weights",
+ type=float,
+ nargs="*",
+ default=[10],
+        help="weights for auxiliary loss terms in the ssl loss",
+ )
+
+ # FP16 optimization
+ parser.add_argument(
+ "--required-seq-len-multiple",
+ type=int,
+ default=2,
+        help="pad the input to the encoder so that the sequence length is divisible by this multiple",
+ )
+
+ parser.add_argument(
+        "--attn-type", type=str, default="", help="attention type; if 'espnet', use ESPnet multi-head attention"
+ )
+
+ parser.add_argument(
+ "--pos-enc-type",
+ type=str,
+ default="abs",
+ help="Positional encoding type to use in conformer",
+ )
+
+ parser.add_argument(
+ "--logit-temp", type=float, default=0.1, help="temperature to divide logits by"
+ )
+
+ parser.add_argument(
+ "--dropout-input",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the input (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--dropout-features",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the features (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--num-classes",
+ type=int,
+ nargs="*",
+ default=[504],
+        help="""number of classes; a little larger than the number of clusters
+        (the largest index is used for padding), and the value should be a
+        multiple of 4 for faster computation""",
+ )
+
+ parser.add_argument(
+ "--untie-final-proj",
+ type=bool,
+ default=False,
+ help="use separate projection for each target",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=222,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="hubert/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--pretrained-dir",
+ type=str,
+ help="""The pretrained model dir.
+ It specifies the directory where the pretrained checkpoint is saved.""",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.001, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=100000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=100,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
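+    # Sketch of the schedule (the exact formula lives in ./optim.py's Eden and
+    # should be treated as the source of truth; this is an assumption-level
+    # summary):
+    #   lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
+    #                * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
+    # so the rate stays close to base_lr early on and decays once the batch
+    # count exceeds --lr-batches and the epoch count exceeds --lr-epochs.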
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+        help="The prune range for rnnt loss; it means how many symbols (of context) "
+        "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of the "
+        "loss (where the joiner is just an addition); this simple loss is also "
+        "used for training (as a regularization term). We will scale the simple "
+        "loss with this parameter before adding it to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--sanity-check",
+ type=str2bool,
+ default=False,
+ help="Check if any of the batches in epoch 1 would cause OOM.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=100000,
+        help="""Save checkpoint after processing this number of batches
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
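+    # Illustrative (hypothetical numbers): with average_period=200 and
+    # batch_idx_train=1000 the update above becomes
+    #   model_avg = 0.2 * model + 0.8 * model_avg
+    # i.e. model_avg tracks a running average of the parameters seen so far.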
+
+ parser.add_argument(
+ "--accum-grad",
+ type=int,
+ default=1,
+        help="""Accumulate gradients over this many sub-batches before taking
+        each optimizer step.
+        """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+      contains the number of updates applied to the model so far across
+      epochs.
+
+    - sub_batch_idx_train: It contains the number of sub-batches trained so
+      far across epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "sub_batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for pruned RNN-T loss
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
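+# For example, _to_int_tuple("192,256,384,512,384,256") returns
+# (192, 256, 384, 512, 384, 256); this is how the comma-separated model
+# options above (e.g. --encoder-dim) are parsed.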
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    # Only load a pretrained checkpoint if --pretrained-dir was actually given.
+    if getattr(params, "pretrained_dir", None) is not None:
+ logging.info(f"Loading {params.pretrained_dir}")
+ pretrained = torch.load(params.pretrained_dir)
+ encoder = HubertModel(params)
+ encoder.load_state_dict(pretrained["model"])
+ else:
+ encoder = HubertModel(params)
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
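+    # For example (hypothetical values): --start-batch 50000 resumes from
+    # exp-dir/checkpoint-50000.pt, while --start-batch 0 --start-epoch 3
+    # resumes from exp-dir/epoch-2.pt; with neither set, this returns None and
+    # training starts from scratch.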
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `dataset.HubertAsrDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss, num_frames = model(
+ x=audio,
+ padding_mask=padding_mask,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+            # take down the scale on the simple loss from 1.0 at the start
+            # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
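+            # Illustrative: with s=0.5 and warm_step=2000, at batch_idx_train=1000
+            # simple_loss_scale = 1.0 - 0.5 * (1.0 - 0.5) = 0.75 and
+            # pruned_loss_scale = 0.1 + 0.9 * 0.5 = 0.55; once batch_idx_train
+            # reaches warm_step they settle at s and 1.0 respectively.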
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = num_frames.sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for sub_batch_idx, batch in enumerate(train_dl):
+ params.sub_batch_idx_train += 1
+ batch_idx = sub_batch_idx // params.accum_grad
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss / params.accum_grad).backward()
+
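+            # Gradient accumulation (illustration of the check below): with
+            # --accum-grad 4, sub-batches 0..3 are accumulated and the optimizer
+            # steps at sub_batch_idx 3, 7, 11, ...; params.batch_idx_train counts
+            # optimizer steps rather than sub-batches.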
+ if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
+ params.batch_idx_train += 1
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ else:
+ continue
+
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+    # If the epoch ended in the middle of an accumulation cycle, drop the
+    # partially accumulated gradients.
+    if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
+        optimizer.zero_grad()
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=0)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ train_cuts = (
+ librispeech.train_all_shuf_cuts()
+ if params.full_libri
+ else librispeech.train_clean_100_cuts()
+ )
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts,
+ do_normalize=params.do_normalize,
+ sampler_state_dict=sampler_state_dict,
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+
+ valid_dl = librispeech.valid_dataloaders(
+ valid_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ if params.sanity_check and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `dataset.HubertAsrDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ audio = batch["audio"]
+ logging.info(f"audio shape: {audio.shape}")
+
+ y = sp.encode(batch["supervisions"]["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librilight/SSL/zipformer/hubert_ce.py b/egs/librilight/SSL/zipformer/hubert_ce.py
new file mode 120000
index 000000000..2b8482f78
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/hubert_ce.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/hubert_ce.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/joiner.py b/egs/librilight/SSL/zipformer/joiner.py
new file mode 120000
index 000000000..587823e65
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/model.py b/egs/librilight/SSL/zipformer/model.py
new file mode 120000
index 000000000..ca3daacca
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/model.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/optim.py b/egs/librilight/SSL/zipformer/optim.py
new file mode 120000
index 000000000..bd2153ebf
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/pretrain.py b/egs/librilight/SSL/zipformer/pretrain.py
new file mode 100644
index 000000000..5728dbe75
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/pretrain.py
@@ -0,0 +1,1366 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+#
+# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For hubert model pretraining:
+./zipformer/pretrain.py \
+ --world-size 8 \
+ --num-epochs 400 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 87.5 \
+ --accum-grad 4
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from hubert_ce import HubertModel
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from optim import Eden, ScaledAdam
+from ssl_datamodule import LibriLightDataModule
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
+ return (
+ params.batch_idx_train
+ * params.accum_grad
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ # hubert parameters
+ parser.add_argument(
+ "--label-rate",
+ type=float,
+ default=50,
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=float,
+ default=16000,
+ )
+
+ parser.add_argument(
+ "--extractor-mode",
+ type=str,
+ default="default",
+        help="""mode for the feature extractor; should be one of EXTRACTOR_MODE_CHOICES. default has a single group
+        norm with d groups in the first conv block, whereas layer_norm
+        has layer norms in every block (meant to be used with normalize=True)""",
+ )
+
+ parser.add_argument(
+ "--conv-feature-layers",
+ type=str,
+ default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+ help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
+ )
+
+ parser.add_argument(
+ "--conv-bias", type=bool, default=False, help="include bias in conv encoder"
+ )
+
+ parser.add_argument(
+ "--feature-grad-mult",
+ type=float,
+ default=1.0,
+ help="multiply feature extractor var grads by this",
+ )
+
+ # masking
+ parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
+
+ parser.add_argument(
+ "--mask-prob",
+ type=float,
+ default=0.65,
+ help="probability of replacing a token with mask",
+ )
+
+ parser.add_argument(
+ "--mask-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length",
+ )
+
+ parser.add_argument(
+ "--mask-other",
+ type=float,
+ default=0,
+        help="secondary mask argument (used for more complex distributions); see help in compute_mask_indices",
+ )
+
+ parser.add_argument(
+ "--no-mask-overlap",
+ type=bool,
+ default=False,
+        help="if true, disallow masks to overlap",
+ )
+
+ parser.add_argument(
+ "--mask-min-space",
+ type=int,
+ default=1,
+        help="min space between spans (when overlap is not allowed)",
+ )
+
+ # channel masking
+ parser.add_argument(
+ "--mask-channel-length",
+ type=int,
+ default=10,
+ help="length of the mask for features (channels)",
+ )
+
+ parser.add_argument(
+ "--mask-channel-prob",
+ type=float,
+ default=0.0,
+ help="probability of replacing a feature with 0",
+ )
+
+ parser.add_argument(
+ "--mask-channel-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length for channel masking",
+ )
+
+ parser.add_argument(
+ "--mask-channel-other",
+ type=float,
+ default=0,
+        help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices",
+ )
+
+ parser.add_argument(
+ "--no-mask-channel-overlap",
+ type=bool,
+ default=False,
+        help="if true, disallow channel masks to overlap",
+ )
+
+ parser.add_argument(
+ "--mask-channel-min-space",
+ type=int,
+ default=1,
+        help="min space between spans (when overlap is not allowed)",
+ )
+
+ # loss computation
+ parser.add_argument(
+ "--skip-masked",
+ type=bool,
+ default=False,
+ help="skip computing losses over masked frames",
+ )
+
+ parser.add_argument(
+ "--skip-nomask",
+ type=bool,
+ default=False,
+ help="skip computing losses over unmasked frames",
+ )
+
+ parser.add_argument(
+ "--checkpoint-activations",
+ type=bool,
+ default=False,
+ help="recompute activations and save memory for extra compute",
+ )
+
+ parser.add_argument(
+ "--pred-masked-weight",
+ type=float,
+ default=1,
+ help="weight for masked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--pred-nomask-weight",
+ type=float,
+ default=0,
+        help="weight for unmasked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--loss-weights",
+ type=float,
+ nargs="*",
+ default=[10],
+        help="weights for auxiliary loss terms in the ssl loss",
+ )
+
+ # FP16 optimization
+ parser.add_argument(
+ "--required-seq-len-multiple",
+ type=int,
+ default=2,
+        help="pad the input to the encoder so that the sequence length is divisible by this multiple",
+ )
+
+ parser.add_argument(
+        "--attn-type", type=str, default="", help="attention type; if 'espnet', use ESPnet multi-head attention"
+ )
+
+ parser.add_argument(
+ "--pos-enc-type",
+ type=str,
+ default="abs",
+ help="Positional encoding type to use in conformer",
+ )
+
+ parser.add_argument(
+ "--logit-temp", type=float, default=0.1, help="temperature to divide logits by"
+ )
+
+ parser.add_argument(
+ "--dropout-input",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the input (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--dropout-features",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the features (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--num-classes",
+ type=int,
+ nargs="*",
+ default=[504],
+        help="""number of classes; a little larger than the number of clusters
+        (the largest index is used for padding), and the value should be a
+        multiple of 4 for faster computation""",
+ )
+
+ parser.add_argument(
+ "--untie-final-proj",
+ type=bool,
+ default=False,
+ help="use separate projection for each target",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=400,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=10.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--warmup-batches",
+ type=float,
+ default=5000,
+ help="Eden warmup steps",
+ )
+
+ parser.add_argument(
+ "--warmup-start",
+ type=float,
+ default=0,
+ help="Eden warmup start learning rate",
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--sanity-check",
+ type=str2bool,
+ default=False,
+ help="Check if any of the batches in epoch 1 would cause OOM.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=100000,
+        help="""Save checkpoint after processing this number of batches
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--accum-grad",
+ type=int,
+ default=4,
+        help="""Accumulate gradients over this many sub-batches before taking
+        each optimizer step.
+        """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--max-sample-size",
+ type=float,
+ default=250000,
+ help="max sample size",
+ )
+
+ parser.add_argument(
+ "--min-sample-size",
+ type=float,
+ default=32000,
+ help="min sample size",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+      contains the number of updates applied to the model so far across
+      epochs.
+
+    - sub_batch_idx_train: It contains the number of sub-batches trained so
+      far across epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "sub_batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ model = HubertModel(params)
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+        The model for training. It is an instance of HubertModel in our case.
+ batch:
+ A batch of data. See `dataset.HubertDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+ kmeans = batch["kmeans"].to(device)
+
+ with torch.set_grad_enabled(is_training):
+ loss, num_masked_tokens, logging_output = model(
+ source=audio, target_list=[kmeans], padding_mask=padding_mask
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
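+        # Note: "frames" here counts masked tokens rather than acoustic frames,
+        # so losses logged as loss/frames below are per masked token.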
+ info["frames"] = num_masked_tokens
+ for item in logging_output:
+ info[item] = logging_output[item]
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+        The learning rate scheduler; we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for sub_batch_idx, batch in enumerate(train_dl):
+ params.sub_batch_idx_train += 1
+ batch_idx = sub_batch_idx // params.accum_grad
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ batch_size = batch["kmeans"].shape[0]
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction=sum, so the loss is summed over utterances
+            # in the batch and is not normalized here.
+ scaler.scale(loss / params.accum_grad).backward()
+
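+            # Gradient accumulation: the optimizer is stepped only once every
+            # `accum_grad` sub-batches; the other sub-batches only accumulate
+            # gradients (the loss above is scaled by 1/accum_grad accordingly).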
+ if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
+ params.batch_idx_train += 1
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ else:
+ continue
+
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ if batch_idx % params.accum_grad != params.accum_grad - 1:
+ optimizer.zero_grad()
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
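+        # It is kept in float64 to reduce rounding error in the running average.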
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(
+ optimizer,
+ params.lr_batches,
+ params.lr_epochs,
+ params.warmup_batches,
+ params.warmup_start,
+ )
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librilight = LibriLightDataModule(args)
+
+ train_cuts = librilight.train_all_shuf_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances whose length in samples lies between
+        # min_sample_size and max_sample_size. With the defaults
+        # (32000 and 250000 samples at a 16 kHz sampling rate), this keeps
+        # utterances between 2 s and 15.625 s.
+        #
+        # You can use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the thresholds.
+ if (
+ c.duration < params.min_sample_size / params.sample_rate
+ or c.duration > params.max_sample_size / params.sample_rate
+ ):
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librilight.train_dataloaders(
+ train_cuts,
+ sample_rate=params.sample_rate,
+ label_rate=params.label_rate,
+ random_crop=params.random_crop,
+ pad_audio=False,
+ num_classes=params.num_classes,
+ do_normalize=params.do_normalize,
+ sampler_state_dict=sampler_state_dict,
+ )
+
+ valid_cuts = librilight.dev_clean_cuts()
+ # valid_cuts += librilight.dev_other_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
+
+ valid_dl = librilight.valid_dataloaders(
+ valid_cuts,
+ sample_rate=params.sample_rate,
+ label_rate=params.label_rate,
+ random_crop=params.random_crop,
+ pad_audio=False,
+ num_classes=params.num_classes,
+ do_normalize=params.do_normalize,
+ )
+
+ if params.sanity_check and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ params=params,
+ )
+
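+    # init_scale=1.0 starts mixed precision training without loss scaling;
+    # the GradScaler then adjusts the scale automatically (see also the
+    # grad-scale checks in train_one_epoch).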
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `dataset.HubertDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ audio = batch["audio"]
+ logging.info(f"audio shape: {audio.shape}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriLightDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
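+
+# A hypothetical invocation (paths and flag values are illustrative only; the
+# flags shown are assumed from the options referenced in run() and in the
+# data module above):
+#
+#   python3 <this-script> \
+#     --world-size 4 \
+#     --num-epochs 30 \
+#     --exp-dir zipformer/exp_pretrain \
+#     --use-fp16 1 \
+#     --max-duration 300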
diff --git a/egs/librilight/SSL/zipformer/scaling.py b/egs/librilight/SSL/zipformer/scaling.py
new file mode 120000
index 000000000..24b661dfb
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/ssl_datamodule.py b/egs/librilight/SSL/zipformer/ssl_datamodule.py
new file mode 100644
index 000000000..dc0dbec6c
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/ssl_datamodule.py
@@ -0,0 +1,334 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import glob
+import logging
+import re
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from dataset import HubertDataset
+from lhotse import CutSet, combine, load_manifest_lazy
+from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
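+    """Give each dataloader worker a distinct random seed derived from a base seed."""
+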
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class LibriLightDataModule:
+ """
+ DataModule for SSL experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in SSL
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+
+ This class should be derived for specific corpora used in SSL tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR SSL related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies.",
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/kmeans"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=float,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+ group.add_argument(
+ "--do-normalize",
+ type=str2bool,
+ default=True,
+ help="whether to normalize the data",
+ )
+ group.add_argument(
+ "--random-crop",
+ type=str2bool,
+ default=True,
+ help="audio sample rate",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sample_rate: float = 16000,
+ label_rate: float = 50,
+ random_crop: bool = True,
+ pad_audio: bool = False,
+ num_classes: list = [504],
+ do_normalize: bool = True,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to create train dataset")
+ train = HubertDataset(
+ sample_rate=sample_rate,
+ label_rate=label_rate,
+ random_crop=random_crop,
+ pad_audio=pad_audio,
+ num_classes=num_classes,
+ do_normalize=do_normalize,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(
+ self,
+ cuts_valid: CutSet,
+ sample_rate: float = 16000,
+ label_rate: float = 50,
+ random_crop: bool = True,
+ pad_audio: bool = False,
+ num_classes: list = [504],
+ do_normalize: bool = True,
+ ) -> DataLoader:
+ logging.info("About to create dev dataset")
+ validate = HubertDataset(
+ sample_rate=sample_rate,
+ label_rate=label_rate,
+ random_crop=random_crop,
+ pad_audio=pad_audio,
+ num_classes=num_classes,
+ do_normalize=do_normalize,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(
+ self,
+ cuts: CutSet,
+ sample_rate: float = 16000,
+ label_rate: float = 50,
+ random_crop: bool = True,
+ pad_audio: bool = False,
+ num_classes: list = [504],
+ do_normalize: bool = True,
+ ) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = HubertDataset(
+ sample_rate=sample_rate,
+ label_rate=label_rate,
+ random_crop=random_crop,
+ pad_audio=pad_audio,
+ num_classes=num_classes,
+ do_normalize=do_normalize,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def small_cuts(self) -> CutSet:
+ logging.info("About to get small cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librilight_cuts_small.jsonl.gz"
+ )
+
+ @lru_cache()
+ def medium_cuts(self) -> CutSet:
+ logging.info("About to get medium cuts")
+ filenames = glob.glob(
+ f"{self.args.manifest_dir}/medium_splits/librilight_cuts_medium.*.jsonl.gz"
+ )
+ pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
+ idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
+ idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
+ sorted_filenames = [f[1] for f in idx_filenames]
+ logging.info(
+ f"Loading LibriLight medium {len(sorted_filenames)} splits in lazy mode"
+ )
+
+ return combine(load_manifest_lazy(p) for p in sorted_filenames)
+
+ @lru_cache()
+ def large_cuts(self) -> CutSet:
+ logging.info("About to get large cuts")
+ filenames = glob.glob(
+ f"{self.args.manifest_dir}/large_splits/librilight_cuts_large.*.jsonl.gz"
+ )
+ pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
+ idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
+ idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
+ sorted_filenames = [f[1] for f in idx_filenames]
+ logging.info(
+ f"Loading LibriLight large {len(sorted_filenames)} splits in lazy mode"
+ )
+
+ return combine(load_manifest_lazy(p) for p in sorted_filenames)
+
+ @lru_cache()
+ def train_all_shuf_cuts(self) -> CutSet:
+ logging.info("About to get the shuffled small, medium and large cuts")
+ small_cuts = self.small_cuts()
+ medium_cuts = self.medium_cuts()
+ large_cuts = self.large_cuts()
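+        # CutSet.mux lazily interleaves the three subsets; the weights below
+        # are the (approximate) numbers of cuts in each subset, so sampling is
+        # roughly proportional to subset size without a global shuffle.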
+ return CutSet.mux(
+ small_cuts,
+ medium_cuts,
+ large_cuts,
+ weights=[
+ 122867, # len(small_cuts)
+ 1104071, # len(medium_cuts)
+ 11012085, # len(large_cuts)
+ ],
+ )
+
+ @lru_cache()
+ def dev_clean_cuts(self) -> CutSet:
+ logging.info("About to get dev-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_other_cuts(self) -> CutSet:
+ logging.info("About to get dev-other cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+ )
diff --git a/egs/librilight/SSL/zipformer/utils.py b/egs/librilight/SSL/zipformer/utils.py
new file mode 120000
index 000000000..119992bdb
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/utils.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/utils.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/wav2vec2_module.py b/egs/librilight/SSL/zipformer/wav2vec2_module.py
new file mode 120000
index 000000000..81ad701e4
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/wav2vec2_module.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/wav2vec2_module.py
\ No newline at end of file
diff --git a/egs/librilight/SSL/zipformer/zipformer.py b/egs/librilight/SSL/zipformer/zipformer.py
new file mode 120000
index 000000000..5b3da8cd5
--- /dev/null
+++ b/egs/librilight/SSL/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/SSL/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index 1c8930818..080f81c91 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -35,6 +35,8 @@ The following table lists the differences among them.
| `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
+| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter-efficient adapters |
+| `zipformer_lora` | Upgraded Zipformer | Embedding + Conv1d | Finetune Zipformer with LoRA |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index ebf5e89c4..ee5422aba 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1526,7 +1526,7 @@ done
You may also decode using LODR + LM shallow fusion. This decoding method is proposed in .
It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be
-generated by `generate-lm.sh`, or you may download it from .
+generated by `prepare_lm.sh` at stage 4, or you may download it from .
The decoding command is as follows:
diff --git a/egs/librispeech/ASR/conformer_ctc3/test_model.py b/egs/librispeech/ASR/conformer_ctc3/test_model.py
index b97b7eed8..aa12d6f83 100755
--- a/egs/librispeech/ASR/conformer_ctc3/test_model.py
+++ b/egs/librispeech/ASR/conformer_ctc3/test_model.py
@@ -24,8 +24,7 @@ To run this file, do:
"""
import torch
-
-from train import get_params, get_ctc_model
+from train import get_ctc_model, get_params
def test_model():
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
index 1e59e0858..79728afa4 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
@@ -59,9 +59,9 @@ import onnx
import torch
import torch.nn as nn
from decoder import Decoder
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from emformer import Emformer
from scaling_converter import convert_scaled_to_non_scaled
-from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
index 58f587c91..1deecbfc7 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
@@ -39,7 +39,7 @@ Usage of this script:
import argparse
import logging
import math
-from typing import List
+from typing import List, Optional
import kaldifeat
import sentencepiece as spm
@@ -47,7 +47,6 @@ import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from torch.nn.utils.rnn import pad_sequence
-from typing import Optional, List
def get_parser():
diff --git a/egs/librispeech/ASR/generate-lm.sh b/egs/librispeech/ASR/generate-lm.sh
deleted file mode 100755
index dacd276d1..000000000
--- a/egs/librispeech/ASR/generate-lm.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/usr/bin/env bash
-
-lang_dir=data/lang_bpe_500
-
-for ngram in 2 3 4 5; do
- if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
- ./shared/make_kn_lm.py \
- -ngram-order ${ngram} \
- -text $lang_dir/transcript_tokens.txt \
- -lm $lang_dir/${ngram}gram.arpa
- fi
-
- if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then
- python3 -m kaldilm \
- --read-symbol-table="$lang_dir/tokens.txt" \
- --disambig-symbol='#0' \
- --max-order=${ngram} \
- $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt
- fi
-done
diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py
index 62036467e..d7781687f 100755
--- a/egs/librispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/librispeech/ASR/local/compute_fbank_musan.py
@@ -22,16 +22,25 @@ It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
-
+import argparse
import logging
import os
from pathlib import Path
import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ LilcomChunkyWriter,
+ MonoCut,
+ WhisperFbank,
+ WhisperFbankConfig,
+ combine,
+)
from lhotse.recipes.utils import read_manifests_if_cached
-from icefall.utils import get_executor
+from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@@ -45,11 +54,12 @@ def is_cut_long(c: MonoCut) -> bool:
return c.duration > 5
-def compute_fbank_musan():
+def compute_fbank_musan(
+ num_mel_bins: int = 80, whisper_fbank: bool = False, output_dir: str = "data/fbank"
+):
src_dir = Path("data/manifests")
- output_dir = Path("data/fbank")
+ output_dir = Path(output_dir)
num_jobs = min(15, os.cpu_count())
- num_mel_bins = 80
dataset_parts = (
"music",
@@ -81,7 +91,12 @@ def compute_fbank_musan():
logging.info("Extracting features for Musan")
- extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+ if whisper_fbank:
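+        # Note: WhisperFbank is configured with device="cuda" below, so this
+        # branch requires a GPU.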
+ extractor = WhisperFbank(
+ WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+ )
+ else:
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
@@ -102,8 +117,36 @@ def compute_fbank_musan():
musan_cuts.to_file(musan_cuts_path)
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--num-mel-bins",
+ type=int,
+ default=80,
+ help="""The number of mel bins for Fbank""",
+ )
+ parser.add_argument(
+ "--whisper-fbank",
+ type=str2bool,
+ default=False,
+ help="Use WhisperFbank instead of Fbank. Default: False.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="data/fbank",
+ help="Output directory. Default: data/fbank.",
+ )
+ return parser.parse_args()
+
+
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
- compute_fbank_musan()
+ args = get_args()
+ compute_fbank_musan(
+ num_mel_bins=args.num_mel_bins,
+ whisper_fbank=args.whisper_fbank,
+ output_dir=args.output_dir,
+ )
diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py
index 43142aee4..5979d5b98 100755
--- a/egs/librispeech/ASR/local/train_bpe_model.py
+++ b/egs/librispeech/ASR/local/train_bpe_model.py
@@ -28,6 +28,7 @@
import argparse
import shutil
from pathlib import Path
+from typing import Dict
import sentencepiece as spm
@@ -57,6 +58,18 @@ def get_args():
return parser.parse_args()
+def generate_tokens(lang_dir: Path):
+ """
+ Generate the tokens.txt from a bpe model.
+ """
+ sp = spm.SentencePieceProcessor()
+ sp.load(str(lang_dir / "bpe.model"))
+ token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
+ with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f:
+ for sym, i in token2id.items():
+ f.write(f"{sym} {i}\n")
+
+
def main():
args = get_args()
vocab_size = args.vocab_size
@@ -95,6 +108,8 @@ def main():
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
+ generate_tokens(lang_dir)
+
if __name__ == "__main__":
main()
diff --git a/egs/librispeech/ASR/long_file_recog/recognize.py b/egs/librispeech/ASR/long_file_recog/recognize.py
index 466253446..f4008c23b 100755
--- a/egs/librispeech/ASR/long_file_recog/recognize.py
+++ b/egs/librispeech/ASR/long_file_recog/recognize.py
@@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
"""
import argparse
-import torch.multiprocessing as mp
-import torch
-import torch.nn as nn
import logging
from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Tuple
-
from pathlib import Path
+from typing import List, Optional, Tuple
import k2
import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
from asr_datamodule import AsrDataModule
from beam_search import (
fast_beam_search_one_best,
greedy_search_batch,
modified_beam_search,
)
-from icefall.utils import AttributeDict, convert_timestamp, setup_logger
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
-from lhotse.supervision import AlignmentItem
from lhotse.serialization import SequentialJsonlWriter
+from lhotse.supervision import AlignmentItem
+
+from icefall.utils import AttributeDict, convert_timestamp, setup_logger
def get_parser():
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py
index 2a52e2eec..1ce770128 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py
@@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./lstm_transducer_stateless2/export-onnx-zh.py \
- --lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \
+ --tokens ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char/tokens.txt \
--use-averaged-model 1 \
--epoch 11 \
--avg 1 \
@@ -55,6 +55,7 @@ import logging
from pathlib import Path
from typing import Dict, Optional, Tuple
+import k2
import onnx
import torch
import torch.nn as nn
@@ -70,8 +71,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import setup_logger, str2bool
+from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@@ -128,10 +128,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt.",
)
parser.add_argument(
@@ -441,9 +441,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
index 5712da25e..aeed58dec 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
@@ -218,10 +218,9 @@ def export_decoder_model_jit_trace(
decoder_filename:
The filename to save the exported model.
"""
- y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
- need_pad = torch.tensor([False])
-
- traced_model = torch.jit.trace(decoder_model, (y, need_pad))
+ # TODO(fangjun): Change the function name since we are actually using
+ # torch.jit.script instead of torch.jit.trace
+ traced_model = torch.jit.script(decoder_model)
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py
index c83f38b2a..85e0648d3 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_check.py
@@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp:
import argparse
import logging
+import torch
from onnx_pretrained import OnnxModel
from icefall import is_module_available
-import torch
-
def get_parser():
parser = argparse.ArgumentParser(
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 4a5072cc0..40dc3260d 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -6,8 +6,21 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
-stage=-1
-stop_stage=100
+# Run stage 0 to stage 5 by default.
+stage=0
+stop_stage=5
+
+# Note: This script prepares only the minimal requirements needed by
+# transducer training with BPE units.
+#
+# If you want to use an n-gram LM or an NNLM, please continue with prepare_lm.sh
+# after this script finishes successfully.
+#
+# This script also contains the steps to generate phone based units, but they
+# will not run automatically. You can generate the phone based units with
+# bash prepare.sh --stage -1 --stop-stage -1
+# bash prepare.sh --stage 6 --stop-stage 6
+
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
@@ -17,6 +30,18 @@ stop_stage=100
# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
# You can download them from https://www.openslr.org/12
#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+#
+# The lm directory is not necessary for transducer training with BPE units, but
+# it is needed for phone based modeling. You can download it by running
+# bash prepare.sh --stage -1 --stop-stage -1
+# after which you will see the following files in the directory.
# - $dl_dir/lm
# This directory contains the following files downloaded from
# http://www.openslr.org/resources/11
@@ -28,14 +53,7 @@ stop_stage=100
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
# - librispeech-lm-norm.txt.gz
-#
-# - $dl_dir/musan
-# This directory contains the following directories downloaded from
-# http://www.openslr.org/17/
-#
-# - music
-# - noise
-# - speech
+
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
@@ -60,6 +78,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
+log "Running prepare.sh"
+
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
@@ -159,13 +179,49 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
- log "Stage 5: Prepare phone based lang"
+ log "Stage 5: Prepare BPE based lang"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ mkdir -p $lang_dir
+
+ if [ ! -f $lang_dir/transcript_words.txt ]; then
+ log "Generate data for BPE training"
+ files=$(
+ find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
+ )
+ for f in ${files[@]}; do
+ cat $f | cut -d " " -f 2-
+ done > $lang_dir/transcript_words.txt
+ fi
+
+ if [ ! -f $lang_dir/bpe.model ]; then
+ ./local/train_bpe_model.py \
+ --lang-dir $lang_dir \
+ --vocab-size $vocab_size \
+ --transcript $lang_dir/transcript_words.txt
+ fi
+ done
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Prepare phone based lang"
lang_dir=data/lang_phone
mkdir -p $lang_dir
- (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) |
- cat - $dl_dir/lm/librispeech-lexicon.txt |
- sort | uniq > $lang_dir/lexicon.txt
+ if [ ! -f $dl_dir/lm/librispeech-lexicon.txt ]; then
+ log "No lexicon file in $dl_dir/lm, please run :"
+ log "prepare.sh --stage -1 --stop-stage -1"
+ exit -1
+ fi
+
+ if [ ! -f $lang_dir/lexicon.txt ]; then
+ (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) |
+ cat - $dl_dir/lm/librispeech-lexicon.txt |
+ sort | uniq > $lang_dir/lexicon.txt
+ fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
@@ -187,253 +243,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
$lang_dir/L_disambig.fst
fi
fi
-
-
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
- log "Stage 6: Prepare BPE based lang"
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bpe_${vocab_size}
- mkdir -p $lang_dir
- # We reuse words.txt from phone based lexicon
- # so that the two can share G.pt later.
- cp data/lang_phone/words.txt $lang_dir
-
- if [ ! -f $lang_dir/transcript_words.txt ]; then
- log "Generate data for BPE training"
- files=$(
- find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
- find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
- find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
- )
- for f in ${files[@]}; do
- cat $f | cut -d " " -f 2-
- done > $lang_dir/transcript_words.txt
- fi
-
- if [ ! -f $lang_dir/bpe.model ]; then
- ./local/train_bpe_model.py \
- --lang-dir $lang_dir \
- --vocab-size $vocab_size \
- --transcript $lang_dir/transcript_words.txt
- fi
-
- if [ ! -f $lang_dir/L_disambig.pt ]; then
- ./local/prepare_lang_bpe.py --lang-dir $lang_dir
-
- log "Validating $lang_dir/lexicon.txt"
- ./local/validate_bpe_lexicon.py \
- --lexicon $lang_dir/lexicon.txt \
- --bpe-model $lang_dir/bpe.model
- fi
-
- if [ ! -f $lang_dir/L.fst ]; then
- log "Converting L.pt to L.fst"
- ./shared/convert-k2-to-openfst.py \
- --olabels aux_labels \
- $lang_dir/L.pt \
- $lang_dir/L.fst
- fi
-
- if [ ! -f $lang_dir/L_disambig.fst ]; then
- log "Converting L_disambig.pt to L_disambig.fst"
- ./shared/convert-k2-to-openfst.py \
- --olabels aux_labels \
- $lang_dir/L_disambig.pt \
- $lang_dir/L_disambig.fst
- fi
- done
-fi
-
-if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
- log "Stage 7: Prepare bigram token-level P for MMI training"
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bpe_${vocab_size}
-
- if [ ! -f $lang_dir/transcript_tokens.txt ]; then
- ./local/convert_transcript_words_to_tokens.py \
- --lexicon $lang_dir/lexicon.txt \
- --transcript $lang_dir/transcript_words.txt \
- --oov "" \
- > $lang_dir/transcript_tokens.txt
- fi
-
- if [ ! -f $lang_dir/P.arpa ]; then
- ./shared/make_kn_lm.py \
- -ngram-order 2 \
- -text $lang_dir/transcript_tokens.txt \
- -lm $lang_dir/P.arpa
- fi
-
- if [ ! -f $lang_dir/P.fst.txt ]; then
- python3 -m kaldilm \
- --read-symbol-table="$lang_dir/tokens.txt" \
- --disambig-symbol='#0' \
- --max-order=2 \
- $lang_dir/P.arpa > $lang_dir/P.fst.txt
- fi
- done
-fi
-
-if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
- log "Stage 8: Prepare G"
- # We assume you have installed kaldilm, if not, please install
- # it using: pip install kaldilm
-
- mkdir -p data/lm
- if [ ! -f data/lm/G_3_gram.fst.txt ]; then
- # It is used in building HLG
- python3 -m kaldilm \
- --read-symbol-table="data/lang_phone/words.txt" \
- --disambig-symbol='#0' \
- --max-order=3 \
- $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
- fi
-
- if [ ! -f data/lm/G_4_gram.fst.txt ]; then
- # It is used for LM rescoring
- python3 -m kaldilm \
- --read-symbol-table="data/lang_phone/words.txt" \
- --disambig-symbol='#0' \
- --max-order=4 \
- $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
- fi
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bpe_${vocab_size}
-
- if [ ! -f $lang_dir/HL.fst ]; then
- ./local/prepare_lang_fst.py \
- --lang-dir $lang_dir \
- --ngram-G ./data/lm/G_3_gram.fst.txt
- fi
- done
-fi
-
-if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
- log "Stage 9: Compile HLG"
- ./local/compile_hlg.py --lang-dir data/lang_phone
-
- # Note If ./local/compile_hlg.py throws OOM,
- # please switch to the following command
- #
- # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bpe_${vocab_size}
- ./local/compile_hlg.py --lang-dir $lang_dir
-
- # Note If ./local/compile_hlg.py throws OOM,
- # please switch to the following command
- #
- # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
- done
-fi
-
-# Compile LG for RNN-T fast_beam_search decoding
-if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
- log "Stage 10: Compile LG"
- ./local/compile_lg.py --lang-dir data/lang_phone
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bpe_${vocab_size}
- ./local/compile_lg.py --lang-dir $lang_dir
- done
-fi
-
-if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
- log "Stage 11: Generate LM training data"
-
- for vocab_size in ${vocab_sizes[@]}; do
- log "Processing vocab_size == ${vocab_size}"
- lang_dir=data/lang_bpe_${vocab_size}
- out_dir=data/lm_training_bpe_${vocab_size}
- mkdir -p $out_dir
-
- ./local/prepare_lm_training_data.py \
- --bpe-model $lang_dir/bpe.model \
- --lm-data $dl_dir/lm/librispeech-lm-norm.txt \
- --lm-archive $out_dir/lm_data.pt
- done
-fi
-
-if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
- log "Stage 12: Generate LM validation data"
-
- for vocab_size in ${vocab_sizes[@]}; do
- log "Processing vocab_size == ${vocab_size}"
- out_dir=data/lm_training_bpe_${vocab_size}
- mkdir -p $out_dir
-
- if [ ! -f $out_dir/valid.txt ]; then
- files=$(
- find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt"
- find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt"
- )
- for f in ${files[@]}; do
- cat $f | cut -d " " -f 2-
- done > $out_dir/valid.txt
- fi
-
- lang_dir=data/lang_bpe_${vocab_size}
- ./local/prepare_lm_training_data.py \
- --bpe-model $lang_dir/bpe.model \
- --lm-data $out_dir/valid.txt \
- --lm-archive $out_dir/lm_data-valid.pt
- done
-fi
-
-if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
- log "Stage 13: Generate LM test data"
-
- for vocab_size in ${vocab_sizes[@]}; do
- log "Processing vocab_size == ${vocab_size}"
- out_dir=data/lm_training_bpe_${vocab_size}
- mkdir -p $out_dir
-
- if [ ! -f $out_dir/test.txt ]; then
- files=$(
- find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt"
- find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt"
- )
- for f in ${files[@]}; do
- cat $f | cut -d " " -f 2-
- done > $out_dir/test.txt
- fi
-
- lang_dir=data/lang_bpe_${vocab_size}
- ./local/prepare_lm_training_data.py \
- --bpe-model $lang_dir/bpe.model \
- --lm-data $out_dir/test.txt \
- --lm-archive $out_dir/lm_data-test.pt
- done
-fi
-
-if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
- log "Stage 14: Sort LM training data"
- # Sort LM training data by sentence length in descending order
- # for ease of training.
- #
- # Sentence length equals to the number of BPE tokens
- # in a sentence.
-
- for vocab_size in ${vocab_sizes[@]}; do
- out_dir=data/lm_training_bpe_${vocab_size}
- mkdir -p $out_dir
- ./local/sort_lm_training_data.py \
- --in-lm-data $out_dir/lm_data.pt \
- --out-lm-data $out_dir/sorted_lm_data.pt \
- --out-statistics $out_dir/statistics.txt
-
- ./local/sort_lm_training_data.py \
- --in-lm-data $out_dir/lm_data-valid.pt \
- --out-lm-data $out_dir/sorted_lm_data-valid.pt \
- --out-statistics $out_dir/statistics-valid.txt
-
- ./local/sort_lm_training_data.py \
- --in-lm-data $out_dir/lm_data-test.pt \
- --out-lm-data $out_dir/sorted_lm_data-test.pt \
- --out-statistics $out_dir/statistics-test.txt
- done
-fi
diff --git a/egs/librispeech/ASR/prepare_lm.sh b/egs/librispeech/ASR/prepare_lm.sh
new file mode 100755
index 000000000..a8eb5ca78
--- /dev/null
+++ b/egs/librispeech/ASR/prepare_lm.sh
@@ -0,0 +1,262 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+# This script generates the n-gram LM / NNLM and related files needed for decoding.
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/lm
+# This directory contains the following files downloaded from
+# http://www.openslr.org/resources/11
+#
+# - 3-gram.pruned.1e-7.arpa.gz
+# - 3-gram.pruned.1e-7.arpa
+# - 4-gram.arpa.gz
+# - 4-gram.arpa
+# - librispeech-vocab.txt
+# - librispeech-lexicon.txt
+# - librispeech-lm-norm.txt.gz
+#
+
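+# Sourcing prepare.sh (rather than executing it) also makes its variables,
+# e.g. dl_dir and vocab_sizes, and its log() function available below.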
+. prepare.sh --stage -1 --stop-stage 6 || exit 1
+
+log "Running prepare_lm.sh"
+
+stage=0
+stop_stage=100
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Prepare BPE based lexicon."
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ # We reuse words.txt from phone based lexicon
+ # so that the two can share G.pt later.
+ cp data/lang_phone/words.txt $lang_dir
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+ log "Validating $lang_dir/lexicon.txt"
+ ./local/validate_bpe_lexicon.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --bpe-model $lang_dir/bpe.model
+ fi
+
+ if [ ! -f $lang_dir/L.fst ]; then
+ log "Converting L.pt to L.fst"
+ ./shared/convert-k2-to-openfst.py \
+ --olabels aux_labels \
+ $lang_dir/L.pt \
+ $lang_dir/L.fst
+ fi
+
+ if [ ! -f $lang_dir/L_disambig.fst ]; then
+ log "Converting L_disambig.pt to L_disambig.fst"
+ ./shared/convert-k2-to-openfst.py \
+ --olabels aux_labels \
+ $lang_dir/L_disambig.pt \
+ $lang_dir/L_disambig.fst
+ fi
+ done
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare word level G"
+ # We assume you have installed kaldilm, if not, please install
+ # it using: pip install kaldilm
+
+ mkdir -p data/lm
+ if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+ # It is used in building HLG
+ python3 -m kaldilm \
+ --read-symbol-table="data/lang_phone/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=3 \
+ $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
+ fi
+
+ if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+ # It is used for LM rescoring
+ python3 -m kaldilm \
+ --read-symbol-table="data/lang_phone/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=4 \
+ $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
+ fi
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+
+ if [ ! -f $lang_dir/HL.fst ]; then
+ ./local/prepare_lang_fst.py \
+ --lang-dir $lang_dir \
+ --ngram-G ./data/lm/G_3_gram.fst.txt
+ fi
+ done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compile HLG"
+ ./local/compile_hlg.py --lang-dir data/lang_phone
+
+ # Note If ./local/compile_hlg.py throws OOM,
+ # please switch to the following command
+ #
+ # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ ./local/compile_hlg.py --lang-dir $lang_dir
+
+ # Note If ./local/compile_hlg.py throws OOM,
+ # please switch to the following command
+ #
+ # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
+ done
+fi
+
+# Compile LG for RNN-T fast_beam_search decoding
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Compile LG"
+ ./local/compile_lg.py --lang-dir data/lang_phone
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ ./local/compile_lg.py --lang-dir $lang_dir
+ done
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Prepare token level ngram G"
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+
+ if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+ ./local/convert_transcript_words_to_tokens.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --transcript $lang_dir/transcript_words.txt \
+ --oov "" \
+ > $lang_dir/transcript_tokens.txt
+ fi
+
+ for ngram in 2 3 4 5; do
+ if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
+ ./shared/make_kn_lm.py \
+ -ngram-order ${ngram} \
+ -text $lang_dir/transcript_tokens.txt \
+ -lm $lang_dir/${ngram}gram.arpa
+ fi
+
+ if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_dir/tokens.txt" \
+ --disambig-symbol='#0' \
+ --max-order=${ngram} \
+ $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt
+ fi
+ done
+ done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Generate NNLM training data"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ log "Processing vocab_size == ${vocab_size}"
+ lang_dir=data/lang_bpe_${vocab_size}
+ out_dir=data/lm_training_bpe_${vocab_size}
+ mkdir -p $out_dir
+
+ ./local/prepare_lm_training_data.py \
+ --bpe-model $lang_dir/bpe.model \
+ --lm-data $dl_dir/lm/librispeech-lm-norm.txt \
+ --lm-archive $out_dir/lm_data.pt
+ done
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Generate NNLM validation data"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ log "Processing vocab_size == ${vocab_size}"
+ out_dir=data/lm_training_bpe_${vocab_size}
+ mkdir -p $out_dir
+
+ if [ ! -f $out_dir/valid.txt ]; then
+ files=$(
+ find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt"
+ )
+ for f in ${files[@]}; do
+ cat $f | cut -d " " -f 2-
+ done > $out_dir/valid.txt
+ fi
+
+ lang_dir=data/lang_bpe_${vocab_size}
+ ./local/prepare_lm_training_data.py \
+ --bpe-model $lang_dir/bpe.model \
+ --lm-data $out_dir/valid.txt \
+ --lm-archive $out_dir/lm_data-valid.pt
+ done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Generate NNLM test data"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ log "Processing vocab_size == ${vocab_size}"
+ out_dir=data/lm_training_bpe_${vocab_size}
+ mkdir -p $out_dir
+
+ if [ ! -f $out_dir/test.txt ]; then
+ files=$(
+ find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt"
+ )
+ for f in ${files[@]}; do
+ cat $f | cut -d " " -f 2-
+ done > $out_dir/test.txt
+ fi
+
+ lang_dir=data/lang_bpe_${vocab_size}
+ ./local/prepare_lm_training_data.py \
+ --bpe-model $lang_dir/bpe.model \
+ --lm-data $out_dir/test.txt \
+ --lm-archive $out_dir/lm_data-test.pt
+ done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ log "Stage 8: Sort NNLM training data"
+ # Sort LM training data by sentence length in descending order
+ # for ease of training.
+ #
+ # Sentence length equals the number of BPE tokens
+ # in a sentence.
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ out_dir=data/lm_training_bpe_${vocab_size}
+ mkdir -p $out_dir
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data.pt \
+ --out-lm-data $out_dir/sorted_lm_data.pt \
+ --out-statistics $out_dir/statistics.txt
+
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data-valid.pt \
+ --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+ --out-statistics $out_dir/statistics-valid.txt
+
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data-test.pt \
+ --out-lm-data $out_dir/sorted_lm_data-test.pt \
+ --out-statistics $out_dir/statistics-test.txt
+ done
+fi
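The G_*.fst.txt files written by kaldilm above are plain OpenFST text-format FSTs. A minimal sketch of how such a file is typically loaded downstream, assuming k2 is installed and the paths produced by the stages above:

    import k2

    # G_4_gram.fst.txt was written by `python3 -m kaldilm` in the stage above.
    # acceptor=False keeps separate input/output labels on each arc.
    with open("data/lm/G_4_gram.fst.txt") as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)

    print("number of arcs in G:", G.num_arcs)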
diff --git a/egs/librispeech/ASR/prepare_mmi.sh b/egs/librispeech/ASR/prepare_mmi.sh
new file mode 100755
index 000000000..d8a6e0caf
--- /dev/null
+++ b/egs/librispeech/ASR/prepare_mmi.sh
@@ -0,0 +1,45 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+
+. prepare.sh --stage -1 --stop-stage 6 || exit 1
+
+log "Running prepare_mmi.sh"
+
+stage=0
+stop_stage=100
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Prepare bigram token-level P for MMI training"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+
+ if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+ ./local/convert_transcript_words_to_tokens.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --transcript $lang_dir/transcript_words.txt \
+ --oov "" \
+ > $lang_dir/transcript_tokens.txt
+ fi
+
+ if [ ! -f $lang_dir/P.arpa ]; then
+ ./shared/make_kn_lm.py \
+ -ngram-order 2 \
+ -text $lang_dir/transcript_tokens.txt \
+ -lm $lang_dir/P.arpa
+ fi
+
+ if [ ! -f $lang_dir/P.fst.txt ]; then
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_dir/tokens.txt" \
+ --disambig-symbol='#0' \
+ --max-order=2 \
+ $lang_dir/P.arpa > $lang_dir/P.fst.txt
+ fi
+ done
+fi
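For orientation, the word-to-token conversion performed by ./local/convert_transcript_words_to_tokens.py can be pictured with the following hypothetical sketch; the function name and lexicon format are illustrative only, not the actual script:

    # Hypothetical illustration: map each word of a transcript to its token
    # sequence via a lexicon; unknown words fall back to the OOV entry.
    def words_to_tokens(line, lexicon, oov="<UNK>"):
        tokens = []
        for word in line.split():
            tokens.extend(lexicon.get(word, lexicon[oov]))
        return " ".join(tokens)

    # "_HE" etc. stand in for BPE pieces.
    lexicon = {"HELLO": ["_HE", "LL", "O"], "<UNK>": ["<UNK>"]}
    print(words_to_tokens("HELLO WORLD", lexicon))  # _HE LL O <UNK>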
diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
index ee7556e49..be36c06b6 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
@@ -286,6 +286,8 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
bucket_method="equal_duration",
drop_last=True,
)
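The buffer_size and shuffle_buffer_size arguments added here (and in the other datamodules below) scale lhotse's dynamic bucketing buffers with the number of buckets. A sketch of the full sampler construction, assuming a lhotse version whose DynamicBucketingSampler accepts these keyword arguments and an illustrative manifest path:

    from lhotse import load_manifest_lazy
    from lhotse.dataset import DynamicBucketingSampler

    cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")

    num_buckets = 30
    train_sampler = DynamicBucketingSampler(
        cuts,
        max_duration=600.0,  # seconds of audio per batch
        shuffle=True,
        num_buckets=num_buckets,
        buffer_size=num_buckets * 2000,  # cuts buffered for bucket estimation
        shuffle_buffer_size=num_buckets * 5000,  # cuts buffered for shuffling
        drop_last=True,
    )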
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
index ec2c9d580..e42a5c6ef 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
@@ -159,6 +159,7 @@ def main():
# Load id of the token and the vocab size
 params.blank_id = token_table["<blk>"]
+ params.unk_id = token_table["<unk>"]
 params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
index 03847b449..b961611f7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
@@ -91,7 +91,7 @@ class Decoder(nn.Module):
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
- embedding_out = self.embedding(y)
+ embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
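The clamp-and-mask idiom above allows padded positions in `y` to be encoded as -1: clamping makes the embedding lookup legal, and the mask zeroes out the rows that correspond to the padded ids. A tiny self-contained check of the behaviour:

    import torch

    emb = torch.nn.Embedding(num_embeddings=5, embedding_dim=3)
    y = torch.tensor([[-1, 2, 4]])  # -1 marks a padded/invalid position

    out = emb(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)

    assert torch.all(out[0, 0] == 0)              # padded slot becomes all zeros
    assert torch.equal(out[0, 1], emb.weight[2])  # valid ids are left untouched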
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py
index b844ba613..9762d878c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py
@@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py
import argparse
import logging
+
import sentencepiece as spm
import torch
+from train import add_model_arguments, get_encoder_model, get_params
from icefall.profiler import get_model_profile
-from train import get_encoder_model, add_model_arguments, get_params
def get_parser():
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py
index 8134d43f8..a235d7b13 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/onnx_decode.py
@@ -75,8 +75,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 7fcd242fc..66c84b2a9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
@@ -31,6 +32,7 @@ from icefall.rnn_lm.model import RnnLmModel
from icefall.transformer_lm.model import TransformerLM
from icefall.utils import (
DecodingResults,
+ KeywordResult,
add_eos,
add_sos,
get_texts,
@@ -789,6 +791,8 @@ class Hypothesis:
# It contains only one entry.
log_prob: torch.Tensor
+ ac_probs: Optional[List[float]] = None
+
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
@@ -805,6 +809,8 @@ class Hypothesis:
# Context graph state
context_state: Optional[ContextState] = None
+ num_tailing_blanks: int = 0
+
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
@@ -953,6 +959,241 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
return ans
+def keywords_search(
+ model: nn.Module,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ keywords_graph: ContextGraph,
+ beam: int = 4,
+ num_tailing_blanks: int = 0,
+ blank_penalty: float = 0,
+) -> List[List[KeywordResult]]:
+ """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+
+ Args:
+ model:
+ The transducer model.
+ encoder_out:
+ Output from the encoder. Its shape is (N, T, C).
+ encoder_out_lens:
+ A 1-D tensor of shape (N,), containing number of valid frames in
+ encoder_out before padding.
+ keywords_graph:
+ An instance of ContextGraph containing keywords and their configurations.
+ beam:
+ Number of active paths during the beam search.
+ num_tailing_blanks:
+ The number of trailing blanks that must follow a keyword before it is
+ reported. This handles the case where one keyword is a prefix of
+ another. In most cases, you can just set it to 0.
+ blank_penalty:
+ The score used to penalize blank probability.
+ Returns:
+ Return a list of list of KeywordResult.
+ """
+ assert encoder_out.ndim == 3, encoder_out.shape
+ assert encoder_out.size(0) >= 1, encoder_out.size(0)
+ assert keywords_graph is not None
+
+ packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+ input=encoder_out,
+ lengths=encoder_out_lens.cpu(),
+ batch_first=True,
+ enforce_sorted=False,
+ )
+
+ blank_id = model.decoder.blank_id
+ unk_id = getattr(model, "unk_id", blank_id)
+ context_size = model.decoder.context_size
+ device = next(model.parameters()).device
+
+ batch_size_list = packed_encoder_out.batch_sizes.tolist()
+ N = encoder_out.size(0)
+ assert torch.all(encoder_out_lens > 0), encoder_out_lens
+ assert N == batch_size_list[0], (N, batch_size_list)
+
+ B = [HypothesisList() for _ in range(N)]
+ for i in range(N):
+ B[i].add(
+ Hypothesis(
+ ys=[-1] * (context_size - 1) + [blank_id],
+ log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+ context_state=keywords_graph.root,
+ timestamp=[],
+ ac_probs=[],
+ )
+ )
+
+ encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+ offset = 0
+ finalized_B = []
+ sorted_ans = [[] for _ in range(N)]
+ for t, batch_size in enumerate(batch_size_list):
+ start = offset
+ end = offset + batch_size
+ current_encoder_out = encoder_out.data[start:end]
+ current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
+ # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+ offset = end
+
+ finalized_B = B[batch_size:] + finalized_B
+ B = B[:batch_size]
+
+ hyps_shape = get_hyps_shape(B).to(device)
+
+ A = [list(b) for b in B]
+
+ B = [HypothesisList() for _ in range(batch_size)]
+
+ ys_log_probs = torch.cat(
+ [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+ ) # (num_hyps, 1)
+
+ decoder_input = torch.tensor(
+ [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+ device=device,
+ dtype=torch.int64,
+ ) # (num_hyps, context_size)
+
+ decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+ decoder_out = model.joiner.decoder_proj(decoder_out)
+ # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
+
+ # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+ # as index, so we use `to(torch.int64)` below.
+ current_encoder_out = torch.index_select(
+ current_encoder_out,
+ dim=0,
+ index=hyps_shape.row_ids(1).to(torch.int64),
+ ) # (num_hyps, 1, 1, encoder_out_dim)
+
+ logits = model.joiner(
+ current_encoder_out,
+ decoder_out,
+ project_input=False,
+ ) # (num_hyps, 1, 1, vocab_size)
+
+ logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
+
+ if blank_penalty != 0:
+ logits[:, 0] -= blank_penalty
+
+ probs = logits.softmax(dim=-1) # (num_hyps, vocab_size)
+
+ log_probs = probs.log()
+
+ probs = probs.reshape(-1)
+
+ log_probs.add_(ys_log_probs)
+
+ vocab_size = log_probs.size(-1)
+
+ log_probs = log_probs.reshape(-1)
+
+ row_splits = hyps_shape.row_splits(1) * vocab_size
+ log_probs_shape = k2.ragged.create_ragged_shape2(
+ row_splits=row_splits, cached_tot_size=log_probs.numel()
+ )
+ ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
+ ragged_probs = k2.RaggedTensor(shape=log_probs_shape, value=probs)
+
+ for i in range(batch_size):
+ topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+ hyp_probs = ragged_probs[i].tolist()
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+ topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+ for k in range(len(topk_hyp_indexes)):
+ hyp_idx = topk_hyp_indexes[k]
+ hyp = A[i][hyp_idx]
+ new_ys = hyp.ys[:]
+ new_token = topk_token_indexes[k]
+ new_timestamp = hyp.timestamp[:]
+ new_ac_probs = hyp.ac_probs[:]
+ context_score = 0
+ new_context_state = hyp.context_state
+ new_num_tailing_blanks = hyp.num_tailing_blanks + 1
+ if new_token not in (blank_id, unk_id):
+ new_ys.append(new_token)
+ new_timestamp.append(t)
+ new_ac_probs.append(hyp_probs[topk_indexes[k]])
+ (
+ context_score,
+ new_context_state,
+ _,
+ ) = keywords_graph.forward_one_step(hyp.context_state, new_token)
+ new_num_tailing_blanks = 0
+ if new_context_state.token == -1: # root
+ new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id]
+
+ new_log_prob = topk_log_probs[k] + context_score
+
+ new_hyp = Hypothesis(
+ ys=new_ys,
+ log_prob=new_log_prob,
+ timestamp=new_timestamp,
+ ac_probs=new_ac_probs,
+ context_state=new_context_state,
+ num_tailing_blanks=new_num_tailing_blanks,
+ )
+ B[i].add(new_hyp)
+
+ top_hyp = B[i].get_most_probable(length_norm=True)
+ matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
+ if matched:
+ ac_prob = (
+ sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
+ )
+ if (
+ matched
+ and top_hyp.num_tailing_blanks > num_tailing_blanks
+ and ac_prob >= matched_state.ac_threshold
+ ):
+ keyword = KeywordResult(
+ hyps=top_hyp.ys[-matched_state.level :],
+ timestamps=top_hyp.timestamp[-matched_state.level :],
+ phrase=matched_state.phrase,
+ )
+ sorted_ans[i].append(keyword)
+ B[i] = HypothesisList()
+ B[i].add(
+ Hypothesis(
+ ys=[-1] * (context_size - 1) + [blank_id],
+ log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+ context_state=keywords_graph.root,
+ timestamp=[],
+ ac_probs=[],
+ )
+ )
+
+ B = B + finalized_B
+
+ for i, hyps in enumerate(B):
+ top_hyp = hyps.get_most_probable(length_norm=True)
+ matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
+ if matched:
+ ac_prob = (
+ sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
+ )
+ if matched and ac_prob >= matched_state.ac_threshold:
+ keyword = KeywordResult(
+ hyps=top_hyp.ys[-matched_state.level :],
+ timestamps=top_hyp.timestamp[-matched_state.level :],
+ phrase=matched_state.phrase,
+ )
+ sorted_ans[i].append(keyword)
+
+ ans = []
+ unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+ for i in range(N):
+ ans.append(sorted_ans[unsorted_indices[i]])
+ return ans
+
+
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
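To make the triggering rule used by keywords_search above concrete: a keyword is emitted only when the best hypothesis reaches a matched state, has enough trailing blanks, and the average per-token acoustic probability over the matched tokens clears that keyword's threshold. A small numeric sketch of the decision, with all values made up:

    # Illustrative values only; in keywords_search they come from the top
    # hypothesis and from the matched ContextGraph state.
    ac_probs = [0.62, 0.91, 0.88, 0.95]  # per-token probabilities of the hypothesis
    level = 3                            # number of tokens in the matched keyword
    ac_threshold = 0.80                  # per-keyword acceptance threshold
    num_trailing_blanks = 2              # blanks seen after the last real token
    required_trailing_blanks = 1

    ac_prob = sum(ac_probs[-level:]) / level  # ~0.913
    fires = num_trailing_blanks > required_trailing_blanks and ac_prob >= ac_threshold
    print(fires)  # True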
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
index 057624272..87c62789e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
@@ -223,6 +223,8 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
@@ -256,6 +258,8 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=False,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=False,
)
logging.info("About to create dev dataloader")
@@ -282,6 +286,8 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=False,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
index 5ca4173c1..e2c1d6b5b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py
@@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp:
import argparse
import logging
-from icefall import is_module_available
+import torch
from onnx_pretrained import OnnxModel
-import torch
+from icefall import is_module_available
def get_parser():
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py
index 3b1c72cf1..f8fed9519 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_decode.py
@@ -76,8 +76,7 @@ import torch
import torch.nn as nn
from asr_datamodule import AsrDataModule
from librispeech import LibriSpeech
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py
index 4bf773918..cf0598ca3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py
@@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py
import argparse
import logging
+from typing import Tuple
+
import sentencepiece as spm
import torch
-
-from typing import Tuple
+from scaling import BasicNorm, DoubleSwish
from torch import Tensor, nn
+from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
from icefall.profiler import get_model_profile
-from scaling import BasicNorm, DoubleSwish
-from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
def get_parser():
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py
index 6f26e34b5..b0f76317b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_decode.py
@@ -82,8 +82,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py
index bfb5fe609..ee8196c3f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py
@@ -20,7 +20,6 @@ from typing import List
import k2
import torch
-
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
# The force alignment problem can be formulated as finding
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
index b0e4be0d1..7095c3cc8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
@@ -107,9 +107,6 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
-
-# from asr_datamodule import LibriSpeechAsrDataModule
-from gigaspeech import GigaSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@@ -120,6 +117,9 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
+
+# from asr_datamodule import LibriSpeechAsrDataModule
+from gigaspeech import GigaSpeechAsrDataModule
from gigaspeech_scoring import asr_text_post_processing
from train import add_model_arguments, get_params, get_transducer_model
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
index a7a8ef149..e7546ec45 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
@@ -80,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -976,9 +977,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
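The repeated RuntimeError in the training loops is factored out into icefall.err.raise_grad_scale_is_too_small_error here and in the recipes below. Its body is not part of this patch; a plausible minimal sketch, assuming it only wraps the old exception with a more helpful message:

    # Hypothetical sketch of icefall/err.py; the real helper may word the
    # message differently or add further hints.
    def raise_grad_scale_is_too_small_error(cur_grad_scale: float) -> None:
        raise RuntimeError(
            f"grad_scale is too small, exiting: {cur_grad_scale}. "
            "This usually indicates inf/nan gradients; inspect the recent "
            "batches and the learning rate before resuming training."
        )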
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py
index 37edc0390..3fd14aa47 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/generate_model_from_checkpoint.py
@@ -65,16 +65,15 @@ from typing import Dict, List
import sentencepiece as spm
import torch
-
from train import add_model_arguments, get_params, get_transducer_model
-from icefall.utils import str2bool
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
+from icefall.utils import str2bool
def get_parser():
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py
index cd432fd6f..306f30c2f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py
@@ -294,6 +294,8 @@ class GigaSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py
index 5a068b3b6..1416c6828 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py
@@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py
import argparse
import logging
+from typing import Tuple
+
import sentencepiece as spm
import torch
-
-from typing import Tuple
+from scaling import BasicNorm, DoubleSwish
from torch import Tensor, nn
+from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
from icefall.profiler import get_model_profile
-from scaling import BasicNorm, DoubleSwish
-from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
def get_parser():
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py
index 67585ee47..e00281239 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/onnx_decode.py
@@ -75,8 +75,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
index cdf914df3..1f50eb309 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_model.py
@@ -24,7 +24,6 @@ To run this file, do:
"""
import torch
-
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index fac3706d2..436ec53b4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -878,9 +879,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
index d8fa08372..b35e56abc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -902,9 +903,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py
index 01ba7b711..e2f08abc6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py
@@ -118,8 +118,8 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
-from train import add_model_arguments, get_params, get_transducer_model
from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py
index a902358ae..2faec7ade 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py
@@ -18,10 +18,7 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
-from scaling import (
- ActivationBalancer,
- ScaledConv1d,
-)
+from scaling import ActivationBalancer, ScaledConv1d
class LConv(nn.Module):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py
index 0ff110370..3a16985bc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py
@@ -52,7 +52,7 @@ import onnxruntime as ort
import sentencepiece as spm
import torch
import torchaudio
-from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
+from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from icefall.utils import make_pad_mask
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py
index 247da0949..07e97bbdb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py
@@ -14,6 +14,7 @@
import torch
from torch import nn
+
from icefall.utils import make_pad_mask
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
index 25a1aa674..c2d877a93 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
@@ -77,6 +77,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -891,9 +892,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 9a6d2155b..8e239e322 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -80,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -880,9 +881,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
index 59a7eb589..67041012d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
@@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@@ -87,7 +87,7 @@ cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/e
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
@@ -113,7 +113,7 @@ cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/e
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
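These export scripts now read --tokens data/lang_bpe_500/tokens.txt instead of the BPE model. A sketch of how a tokens.txt file is typically turned into the ids the exporter needs, assuming k2 is available and the num_tokens helper from icefall.utils used elsewhere in this patch:

    import k2
    from icefall.utils import num_tokens  # helper assumed from icefall.utils

    token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

    blank_id = token_table["<blk>"]
    unk_id = token_table["<unk>"]
    vocab_size = num_tokens(token_table) + 1  # +1 for <blk>
    print(blank_id, unk_id, vocab_size)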
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py
index 442a0a8af..451c35332 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py
@@ -4,7 +4,6 @@
import ncnn
import numpy as np
-
layer_list = []
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
index 999f7e0b4..06127607d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
@@ -42,7 +42,6 @@ import ncnn
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
-
from ncnn_custom_layer import RegisterCustomLayers
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
index e1bdce49d..8bd00bbef 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -80,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -879,9 +880,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
index 1642ef4b7..da5e144c9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
@@ -84,6 +84,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -946,9 +947,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
index 3f271c5b4..646f30ca1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -946,9 +947,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index c500eb3e5..814390ad6 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Piotr Żelasko
+# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
@@ -311,6 +311,8 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
@@ -473,3 +475,18 @@ class LibriSpeechAsrDataModule:
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)
+
+ @lru_cache()
+ def gigaspeech_subset_small_cuts(self) -> CutSet:
+ logging.info("About to get Gigaspeech subset-S cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
+
+ @lru_cache()
+ def gigaspeech_dev_cuts(self) -> CutSet:
+ logging.info("About to get Gigaspeech dev cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+
+ @lru_cache()
+ def gigaspeech_test_cuts(self) -> CutSet:
+ logging.info("About to get Gigaspeech test cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py
index 3acd22ae4..84bd3fc4b 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py
@@ -304,6 +304,8 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
index 6c2bf9ea1..cc4471e2b 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
@@ -1,10 +1,11 @@
import argparse
import logging
import math
+import pprint
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
-import pprint
+
import k2
import sentencepiece as spm
import torch
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py
index 8920764cd..1bfd071de 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py
@@ -66,6 +66,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import UniqLexicon
from icefall.utils import (
@@ -883,9 +884,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py
index 4db50b981..1f0f9bfac 100755
--- a/egs/librispeech/ASR/zipformer/ctc_decode.py
+++ b/egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -88,7 +88,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py
new file mode 100755
index 000000000..3cda337c0
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py
@@ -0,0 +1,1114 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from train import add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+conversational_filler = [
+ "UH",
+ "UHH",
+ "UM",
+ "EH",
+ "MM",
+ "HM",
+ "AH",
+ "HUH",
+ "HA",
+ "ER",
+ "OOF",
+ "HEE",
+ "ACH",
+ "EEE",
+ "EW",
+]
+unk_tags = ["<UNK>", "<unk>"]
+gigaspeech_punctuations = [
+ "<COMMA>",
+ "<PERIOD>",
+ "<QUESTIONMARK>",
+ "<EXCLAMATIONPOINT>",
+]
+gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
+non_scoring_words = (
+ conversational_filler
+ + unk_tags
+ + gigaspeech_punctuations
+ + gigaspeech_garbage_utterance_tags
+)
+
+
+def asr_text_post_processing(text: str) -> str:
+ # 1. convert to uppercase
+ text = text.upper()
+
+ # 2. remove hyphen
+ # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
+ text = text.replace("-", " ")
+
+ # 3. remove non-scoring words from evaluation
+ remaining_words = []
+ for word in text.split():
+ if word in non_scoring_words:
+ continue
+ remaining_words.append(word)
+
+ return " ".join(remaining_words)
+
+
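# A quick sanity check of asr_text_post_processing above: text is uppercased,
# hyphens are split into spaces, and non-scoring words such as "UM" are
# dropped before scoring, e.g.
#
#     >>> asr_text_post_processing("state-of-the-art um e-commerce")
#     'STATE OF THE ART E COMMERCE'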
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+ add_model_arguments(parser)
+
+ return parser
+
+
+def post_processing(
+ results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+ new_results = []
+ for key, ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((key, new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals the
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ An n-gram language model.
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict whose key may be "greedy_search" if greedy search
+      is used, or "beam_7" if a beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = post_processing(results)
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
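+            # The context file is assumed to hold one biasing phrase per line,
+            # e.g. (hypothetical contents):
+            #   LOCH NESS
+            #   ZIPFORMER
+            # Each line is stripped and BPE-encoded before being added to the
+            # ContextGraph below.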
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+ gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts()
+
+ dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts)
+ test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dl = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
index 3c0f74005..1eba6093b 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
@@ -32,7 +32,7 @@ This script exports a CTC model from PyTorch to ONNX.
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
- --left-context-frames 64 \
+ --left-context-frames 128 \
--use-ctc 1
The --chunk-size in training is "16,32,64,-1", so we select one of them
@@ -41,7 +41,7 @@ whose value is "64,128,256,-1".
It will generate the following file inside $repo/exp:
- - ctc-epoch-99-avg-1-chunk-16-left-64.onnx
+ - ctc-epoch-99-avg-1-chunk-16-left-128.onnx
See ./onnx_pretrained-streaming-ctc.py for how to use the exported ONNX models.
"""
diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index 6bc9b1858..5d0c9ea43 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -48,7 +48,7 @@ popd
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
- --left-context-frames 64
+ --left-context-frames 128
The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to `--left-context`,
@@ -56,9 +56,9 @@ whose value is "64,128,256,-1".
It will generate the following 3 files inside $repo/exp:
- - encoder-epoch-99-avg-1-chunk-16-left-64.onnx
- - decoder-epoch-99-avg-1-chunk-16-left-64.onnx
- - joiner-epoch-99-avg-1-chunk-16-left-64.onnx
+ - encoder-epoch-99-avg-1-chunk-16-left-128.onnx
+ - decoder-epoch-99-avg-1-chunk-16-left-128.onnx
+ - joiner-epoch-99-avg-1-chunk-16-left-128.onnx
See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""
@@ -333,6 +333,7 @@ def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
+ feature_dim: int = 80,
) -> None:
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
@@ -343,7 +344,7 @@ def export_encoder_model_onnx(
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
T = decode_chunk_len + encoder_model.pad_length
- x = torch.rand(1, T, 80, dtype=torch.float32)
+ x = torch.rand(1, T, feature_dim, dtype=torch.float32)
init_state = encoder_model.get_init_states()
num_encoders = len(encoder_model.encoder.encoder_dim)
logging.info(f"num_encoders: {num_encoders}")
@@ -724,6 +725,7 @@ def main():
encoder,
encoder_filename,
opset_version=opset_version,
+ feature_dim=params.feature_dim,
)
logging.info(f"Exported encoder to {encoder_filename}")
diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py
new file mode 100755
index 000000000..2f7ec0c17
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/finetune.py
@@ -0,0 +1,1520 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# Fine-tune without mux (i.e. not mixing with the original training data):
+./zipformer/finetune.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --do-finetune 1 \
+ --finetune-ckpt path/to/ckpt \
+ --base-lr 0.0045 \
+ --use-mux 0 \
+ --exp-dir zipformer/exp_finetune \
+ --max-duration 1000
+
+# Fine-tune with mux (i.e. mixing with the original training data):
+./zipformer/finetune.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --do-finetune 1 \
+ --finetune-ckpt path/to/ckpt \
+ --base-lr 0.0045 \
+ --use-mux 1 \
+ --exp-dir zipformer/exp_finetune \
+ --max-duration 1000
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut, CutSet
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+    # Returns the number of batches we would have used so far if we had used the
+    # reference duration. This is for purposes of set_batch_count().
+    # Note that we add a very large constant here so that the ScheduledFloat
+    # variables are already at their end values during fine-tuning.
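+    # For example (hypothetical values): with max_duration=1000 s, world_size=4,
+    # ref_duration=600 s and batch_idx_train=3000, this returns
+    # 3000 * (1000 * 4) / 600 + 100000 = 20000 + 100000 = 120000.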
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ ) + 100000
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--do-finetune",
+ type=str2bool,
+ default=True,
+ help="If true, finetune from a pre-trained checkpoint",
+ )
+ parser.add_argument(
+ "--use-mux",
+ type=str2bool,
+ default=False,
+ help="""
+        If true, mix the new fine-tuning data with the original training data
+        (see CutSet.mux in run()). This is useful if you want to maintain
+        the performance on the original domain.
+ """,
+ )
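+    # A minimal sketch of what --use-mux enables in run() (the weights below are
+    # the cut counts of the two sources; lhotse normalises them internally):
+    #
+    #   train_cuts = CutSet.mux(
+    #       new_domain_cuts, original_cuts,
+    #       weights=[len_new, len_orig],  # hypothetical counts
+    #       stop_early=True,
+    #   )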
+
+ parser.add_argument(
+ "--init-modules",
+ type=str,
+ default=None,
+ help="""
+        Modules to be initialised. It matches all parameters starting with
+        a specific key. The keys are given as a comma-separated list. If None,
+        all modules will be initialised. For example, if you only want to
+        initialise parameters starting with "encoder", use "encoder";
+        if you want to initialise parameters starting with encoder or joiner,
+        use "encoder,joiner".
+        """,
+ )
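+    # A minimal sketch of the prefix matching done in load_model_params()
+    # (illustrative only): --init-modules "encoder,encoder_embed" selects
+    #   [k for k in ckpt["model"] if k.startswith(("encoder.", "encoder_embed."))]
+    # while the decoder and joiner keep their fresh initialisation.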
+
+ parser.add_argument(
+ "--finetune-ckpt",
+ type=str,
+ default=None,
+ help="Fine-tuning from which checkpoint (path to a .pt file)",
+ )
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr",
+ type=float,
+ default=0.0045,
+ help="""The base learning rate.
+ It is set to a very small value as we are doing fine-tuning""",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=100000.0,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. It is set to a very large value here to prevent the lr from decaying too fast
+ during fine-tuning.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=100.0,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ It is set to a very large value here to prevent the lr from decaying too fast
+ during fine-tuning.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
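+    # For example (hypothetical values): at batch_idx_train = 1000 with
+    # average_period = 200, the update is
+    #   model_avg = 0.2 * model + 0.8 * model_avg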
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+                       epochs.
+
+    - log_interval: Print training loss if `batch_idx % log_interval` is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def load_model_params(
+ ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
+):
+ """Load model params from checkpoint
+
+ Args:
+ ckpt (str): Path to the checkpoint
+ model (nn.Module): model to be loaded
+ init_modules (list[str]): List of modules to be initialized
+
+ """
+ logging.info(f"Loading checkpoint from {ckpt}")
+ checkpoint = torch.load(ckpt, map_location="cpu")
+
+ # if module list is empty, load the whole model from ckpt
+ if not init_modules:
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ model.load_state_dict(dst_state_dict, strict=strict)
+ else:
+ model.load_state_dict(checkpoint["model"], strict=strict)
+ else:
+ src_state_dict = checkpoint["model"]
+ dst_state_dict = model.state_dict()
+ for module in init_modules:
+ logging.info(f"Loading parameters starting with prefix {module}")
+ src_keys = [
+ k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ dst_keys = [
+ k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ assert set(src_keys) == set(dst_keys) # two sets should match exactly
+ for key in src_keys:
+ dst_state_dict[key] = src_state_dict.pop(key)
+
+ model.load_state_dict(dst_state_dict, strict=strict)
+
+ return None
+
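+# Example usage of load_model_params (hypothetical checkpoint path): initialise
+# only the encoder and encoder_embed from a pre-trained model, leaving the
+# decoder and joiner randomly initialised:
+#
+#   load_model_params(
+#       ckpt="zipformer/exp/pretrained.pt",
+#       model=model,
+#       init_modules=["encoder", "encoder_embed"],
+#   )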
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mix precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+    valid_dls: List[torch.utils.data.DataLoader],
+ valid_sets: List[str],
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+      valid_dls:
+        Dataloaders for the validation datasets, one per entry in `valid_sets`.
+ scaler:
+ The scaler used for mix precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
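+            # e.g. (hypothetical): a scale of 4.0 is doubled to 8.0 at this check;
+            # between 8.0 and 32.0 it is only doubled every 400 batches; at or
+            # above 32.0 it is left to the scaler's own growth schedule.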
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ for valid_set, valid_dl in zip(valid_sets, valid_dls):
+ logging.info(f"Computing validation loss on {valid_set}")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(
+ f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
+ )
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ # load model parameters for model fine-tuning
+ if params.do_finetune:
+ assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
+ modules = params.init_modules.split(",") if params.init_modules else None
+ checkpoints = load_model_params(
+ ckpt=params.finetune_ckpt, model=model, init_modules=modules
+ )
+        # Need to update model_avg if we initialise from a checkpoint
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+ else:
+ # resuming training
+ assert params.start_epoch > 1, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts()
+ if params.use_mux:
+ librispeech_cuts = librispeech.train_all_shuf_cuts()
+ train_cuts = CutSet.mux(
+ gigaspeech_cuts, # num cuts = 688182
+ librispeech_cuts, # num cuts = 843723
+ weights=[688182, 843723],
+ stop_early=True,
+ )
+ else:
+ train_cuts = gigaspeech_cuts
+ logging.info(train_cuts)
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
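+        # For example (hypothetical cut): a 10 s utterance at 100 frames/s has
+        # num_frames = 1000, so T = ((1000 - 7) // 2 + 1) // 2 = 248.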
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+
+ valid_sets = ["librispeech", "gigaspeech"]
+ valid_dls = [
+ librispeech.valid_dataloaders(valid_cuts),
+ librispeech.valid_dataloaders(gigaspeech_dev_cuts),
+ ]
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dls=valid_dls,
+ valid_sets=valid_sets,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
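+ # `batches` maps each criterion (e.g. the largest total batch duration)
+ # to the set of cuts that maximizes it; running these worst-case batches
+ # once surfaces OOM errors before real training starts.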
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
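+# Limit each process to a single torch thread; this avoids CPU
+# oversubscription when several DDP workers run on the same machine.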
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index f2f86af47..86da3ab29 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -22,9 +22,9 @@ import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
+from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask
-from scaling import ScaledLinear
class AsrModel(nn.Module):
@@ -164,9 +164,9 @@ class AsrModel(nn.Module):
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
- targets=targets,
- input_lengths=encoder_out_lens,
- target_lengths=target_lengths,
+ targets=targets.cpu(),
+ input_lengths=encoder_out_lens.cpu(),
+ target_lengths=target_lengths.cpu(),
reduction="sum",
)
return ctc_loss
diff --git a/egs/librispeech/ASR/zipformer/my_profile.py b/egs/librispeech/ASR/zipformer/my_profile.py
index ca20956fb..7e1fd777a 100755
--- a/egs/librispeech/ASR/zipformer/my_profile.py
+++ b/egs/librispeech/ASR/zipformer/my_profile.py
@@ -22,24 +22,24 @@ Usage: ./zipformer/my_profile.py
import argparse
import logging
+from typing import Tuple
+
import sentencepiece as spm
import torch
-
-from typing import Tuple
-from torch import Tensor, nn
-
-from icefall.utils import make_pad_mask
-from icefall.profiler import get_model_profile
from scaling import BiasNorm
+from torch import Tensor, nn
from train import (
+ add_model_arguments,
get_encoder_embed,
get_encoder_model,
get_joiner_model,
- add_model_arguments,
get_params,
)
from zipformer import BypassModule
+from icefall.profiler import get_model_profile
+from icefall.utils import make_pad_mask
+
def get_parser():
parser = argparse.ArgumentParser(
diff --git a/egs/librispeech/ASR/zipformer/onnx_decode.py b/egs/librispeech/ASR/zipformer/onnx_decode.py
index 356c2a830..449294444 100755
--- a/egs/librispeech/ASR/zipformer/onnx_decode.py
+++ b/egs/librispeech/ASR/zipformer/onnx_decode.py
@@ -77,11 +77,10 @@ from typing import List, Tuple
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
-
-from onnx_pretrained import greedy_search, OnnxModel
+from k2 import SymbolTable
+from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats
-from k2 import SymbolTable
def get_parser():
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py
index a77c3bf2a..114490599 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py
@@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
import argparse
import logging
import math
-from typing import List, Tuple
+from typing import Dict, List, Tuple
import k2
import kaldifeat
-from typing import Dict
import kaldifst
import onnxruntime as ort
import torch
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
index 6ef944514..f7d3e5253 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
@@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
import argparse
import logging
import math
-from typing import List, Tuple
+from typing import Dict, List, Tuple
import k2
import kaldifeat
-from typing import Dict
import kaldifst
import onnxruntime as ort
import torch
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
index ccb3107ea..ebd385364 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
@@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
import argparse
import logging
import math
-from typing import List, Tuple
+from typing import Dict, List, Tuple
import k2
import kaldifeat
-from typing import Dict
import kaldifst
import onnxruntime as ort
import torch
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py
new file mode 100755
index 000000000..a8b08de34
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py
@@ -0,0 +1,439 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
+
+"""
+This script loads ONNX models exported by ./export-onnx-streaming-ctc.py
+and uses them to decode waves.
+
+We use the pre-trained model from
+https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "exp-ctc-rnnt-small/*.pt"
+git lfs pull --include "data/lang_bpe_500/words.txt"
+git lfs pull --include "data/lang_bpe_500/HLG.fst"
+popd
+
+2. Export the model to ONNX
+
+./zipformer/export-onnx-streaming-ctc.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 3 \
+ --exp-dir $repo/exp-ctc-rnnt-small \
+ --causal 1 \
+ --use-ctc 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ \
+ --num-encoder-layers 2,2,2,2,2,2 \
+ --feedforward-dim 512,768,768,768,768,768 \
+ --encoder-dim 192,256,256,256,256,256 \
+ --encoder-unmasked-dim 192,192,192,192,192,192
+
+It will generate the following 2 files inside $repo/exp-ctc-rnnt-small:
+
+ - ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx
+ - ctc-epoch-30-avg-3-chunk-16-left-128.onnx
+
+You can use either the ``int8.onnx`` model or just the ``.onnx`` model.
+
+3. Run this file with the exported ONNX models
+
+python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \
+ --nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
+ --words $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.fst \
+ $repo/test_wavs/0.wav
+
+Note: Even though this script only supports decoding a single file,
+the exported ONNX models do support batch processing.
+
+Note: HLG.fst is generated by ../local/prepare_lang_fst.py
+"""
+
+import argparse
+import logging
+from typing import Dict, List, Tuple
+
+import k2
+import kaldifst
+import numpy as np
+import onnxruntime as ort
+import torch
+import torchaudio
+from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
+from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--nn-model",
+ type=str,
+ required=True,
+ help="Path to the onnx model. ",
+ )
+
+ parser.add_argument(
+ "--words",
+ type=str,
+ required=True,
+ help="""Path to words.txt.""",
+ )
+
+ parser.add_argument(
+ "--HLG",
+ type=str,
+ required=True,
+ help="""Path to HLG.fst.""",
+ )
+
+ parser.add_argument(
+ "sound_file",
+ type=str,
+ help="The input sound file to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. ",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(
+ self,
+ model_filename: str,
+ ):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 1
+
+ self.session_opts = session_opts
+
+ self.init_model(model_filename)
+
+ def init_model(self, model_filename: str):
+ self.model = ort.InferenceSession(
+ model_filename,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ self.init_states()
+
+ def init_states(self, batch_size: int = 1):
+ meta = self.model.get_modelmeta().custom_metadata_map
+ logging.info(f"meta={meta}")
+
+ model_type = meta["model_type"]
+ assert model_type == "zipformer2", model_type
+
+ decode_chunk_len = int(meta["decode_chunk_len"])
+ T = int(meta["T"])
+
+ num_encoder_layers = meta["num_encoder_layers"]
+ encoder_dims = meta["encoder_dims"]
+ cnn_module_kernels = meta["cnn_module_kernels"]
+ left_context_len = meta["left_context_len"]
+ query_head_dims = meta["query_head_dims"]
+ value_head_dims = meta["value_head_dims"]
+ num_heads = meta["num_heads"]
+
+ def to_int_list(s):
+ return list(map(int, s.split(",")))
+
+ num_encoder_layers = to_int_list(num_encoder_layers)
+ encoder_dims = to_int_list(encoder_dims)
+ cnn_module_kernels = to_int_list(cnn_module_kernels)
+ left_context_len = to_int_list(left_context_len)
+ query_head_dims = to_int_list(query_head_dims)
+ value_head_dims = to_int_list(value_head_dims)
+ num_heads = to_int_list(num_heads)
+
+ logging.info(f"decode_chunk_len: {decode_chunk_len}")
+ logging.info(f"T: {T}")
+ logging.info(f"num_encoder_layers: {num_encoder_layers}")
+ logging.info(f"encoder_dims: {encoder_dims}")
+ logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
+ logging.info(f"left_context_len: {left_context_len}")
+ logging.info(f"query_head_dims: {query_head_dims}")
+ logging.info(f"value_head_dims: {value_head_dims}")
+ logging.info(f"num_heads: {num_heads}")
+
+ num_encoders = len(num_encoder_layers)
+
+ self.states = []
+ for i in range(num_encoders):
+ num_layers = num_encoder_layers[i]
+ key_dim = query_head_dims[i] * num_heads[i]
+ embed_dim = encoder_dims[i]
+ nonlin_attn_head_dim = 3 * embed_dim // 4
+ value_dim = value_head_dims[i] * num_heads[i]
+ conv_left_pad = cnn_module_kernels[i] // 2
+
+ for layer in range(num_layers):
+ cached_key = torch.zeros(
+ left_context_len[i], batch_size, key_dim
+ ).numpy()
+ cached_nonlin_attn = torch.zeros(
+ 1, batch_size, left_context_len[i], nonlin_attn_head_dim
+ ).numpy()
+ cached_val1 = torch.zeros(
+ left_context_len[i], batch_size, value_dim
+ ).numpy()
+ cached_val2 = torch.zeros(
+ left_context_len[i], batch_size, value_dim
+ ).numpy()
+ cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
+ cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
+ self.states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+ embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
+ self.states.append(embed_states)
+ processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
+ self.states.append(processed_lens)
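+ # The flattened state list holds 6 cached tensors per layer for every
+ # encoder stack, followed by the embedding cache and the processed-frame
+ # counter; _build_model_input_output() relies on this exact ordering.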
+
+ self.num_encoders = num_encoders
+
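+ # `segment` (= T) is how many feature frames are fed to the model per
+ # chunk, while `offset` (= decode_chunk_len) is how far the stream
+ # advances afterwards; the difference is the extra right context.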
+ self.segment = T
+ self.offset = decode_chunk_len
+
+ def _build_model_input_output(
+ self,
+ x: torch.Tensor,
+ ) -> Tuple[Dict[str, np.ndarray], List[str]]:
+ model_input = {"x": x.numpy()}
+ model_output = ["log_probs"]
+
+ def build_inputs_outputs(tensors, i):
+ assert len(tensors) == 6, len(tensors)
+
+ # (downsample_left, batch_size, key_dim)
+ name = f"cached_key_{i}"
+ model_input[name] = tensors[0]
+ model_output.append(f"new_{name}")
+
+ # (1, batch_size, downsample_left, nonlin_attn_head_dim)
+ name = f"cached_nonlin_attn_{i}"
+ model_input[name] = tensors[1]
+ model_output.append(f"new_{name}")
+
+ # (downsample_left, batch_size, value_dim)
+ name = f"cached_val1_{i}"
+ model_input[name] = tensors[2]
+ model_output.append(f"new_{name}")
+
+ # (downsample_left, batch_size, value_dim)
+ name = f"cached_val2_{i}"
+ model_input[name] = tensors[3]
+ model_output.append(f"new_{name}")
+
+ # (batch_size, embed_dim, conv_left_pad)
+ name = f"cached_conv1_{i}"
+ model_input[name] = tensors[4]
+ model_output.append(f"new_{name}")
+
+ # (batch_size, embed_dim, conv_left_pad)
+ name = f"cached_conv2_{i}"
+ model_input[name] = tensors[5]
+ model_output.append(f"new_{name}")
+
+ for i in range(len(self.states[:-2]) // 6):
+ build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
+
+ # (batch_size, channels, left_pad, freq)
+ name = "embed_states"
+ embed_states = self.states[-2]
+ model_input[name] = embed_states
+ model_output.append(f"new_{name}")
+
+ # (batch_size,)
+ name = "processed_lens"
+ processed_lens = self.states[-1]
+ model_input[name] = processed_lens
+ model_output.append(f"new_{name}")
+
+ return model_input, model_output
+
+ def _update_states(self, states: List[np.ndarray]):
+ self.states = states
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C)
+ Returns:
+ Return a 3-D tensor containing log_probs. Its shape is (N, T', vocab_size),
+ where T' is usually equal to ((T-7)//2 - 3)//2
+ """
+ model_input, model_output_names = self._build_model_input_output(x)
+
+ out = self.model.run(model_output_names, model_input)
+
+ self._update_states(out[1:])
+
+ return torch.from_numpy(out[0])
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ if sample_rate != expected_sample_rate:
+ logging.info(f"Resample {sample_rate} to {expected_sample_rate}")
+ wave = torchaudio.functional.resample(
+ wave,
+ orig_freq=sample_rate,
+ new_freq=expected_sample_rate,
+ )
+ # We use only the first channel
+ ans.append(wave[0].contiguous())
+ return ans
+
+
+def create_streaming_feature_extractor() -> OnlineFeature:
+ """Create a CPU streaming feature extractor.
+
+ At present, it returns a fbank feature extractor with fixed options.
+ In the future, we will support passing in the options from outside.
+
+ Returns:
+ Return a CPU streaming feature extractor.
+ """
+ opts = FbankOptions()
+ opts.device = "cpu"
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
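+ # A non-positive high_freq is interpreted as an offset from the Nyquist
+ # frequency, i.e. 8000 - 400 = 7600 Hz for 16 kHz audio.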
+ return OnlineFbank(opts)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+
+ word_table = k2.SymbolTable.from_file(args.words)
+ model = OnnxModel(model_filename=args.nn_model)
+
+ sample_rate = 16000
+
+ logging.info("Constructing Fbank computer")
+ online_fbank = create_streaming_feature_extractor()
+
+ logging.info(f"Reading sound files: {args.sound_file}")
+ waves = read_sound_files(
+ filenames=[args.sound_file],
+ expected_sample_rate=sample_rate,
+ )[0]
+
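+ # Append 0.3 seconds of silence so the last speech frames are flushed
+ # through the model's right context before the stream ends.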
+ tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
+ wave_samples = torch.cat([waves, tail_padding])
+
+ num_processed_frames = 0
+ segment = model.segment
+ offset = model.offset
+
+ logging.info(f"Loading HLG from {args.HLG}")
+ HLG = kaldifst.StdVectorFst.read(args.HLG)
+
+ decoder_opts = FasterDecoderOptions(max_active=3000)
+ decoder = FasterDecoder(HLG, decoder_opts)
+ decoder.init_decoding()
+
+ chunk = int(1 * sample_rate) # 1 second
+ start = 0
+
+ n = 0
+ while start < wave_samples.numel():
+ end = min(start + chunk, wave_samples.numel())
+
+ # simulate streaming
+ samples = wave_samples[start:end]
+ start += chunk
+
+ online_fbank.accept_waveform(
+ sampling_rate=sample_rate,
+ waveform=samples,
+ )
+
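+ # Run the model whenever a full segment of feature frames is ready;
+ # only `offset` frames are consumed per call, so consecutive segments
+ # overlap by the model's right context.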
+ while online_fbank.num_frames_ready - num_processed_frames >= segment:
+ frames = []
+ for i in range(segment):
+ frames.append(online_fbank.get_frame(num_processed_frames + i))
+
+ frames = torch.cat(frames, dim=0)
+ frames = frames.unsqueeze(0)
+
+ log_probs = model(frames)
+ log_probs = log_probs.squeeze(0).cpu().numpy()
+
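+ # offset=n tells the decoder the global index of this chunk's first
+ # frame; n then advances by the number of frames just produced.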
+ decodable = DecodableCtc(log_probs, offset=n)
+ n += log_probs.shape[0]
+
+ num_processed_frames += offset
+ decoder.advance_decoding(decodable)
+
+ if not decoder.reached_final():
+ logging.info(f"Failed to decode {args.sound_file}")
+ return
+
+ ok, best_path = decoder.get_best_path()
+
+ (
+ ok,
+ isymbols_out,
+ osymbols_out,
+ total_weight,
+ ) = kaldifst.get_linear_symbol_sequence(best_path)
+
+ if not ok:
+ logging.info(f"Failed to get linear symbol sequence for {args.sound_file}")
+ return
+
+ hyps = " ".join([word_table[i] for i in osymbols_out]).lower()
+ logging.info(f"\n{args.sound_file}\n{hyps}")
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py
index 714d8db9a..aaffbfed5 100644
--- a/egs/librispeech/ASR/zipformer/optim.py
+++ b/egs/librispeech/ASR/zipformer/optim.py
@@ -298,11 +298,14 @@ class ScaledAdam(BatchedOptimizer):
# case 2 or case 4
# the input is groups of parameter or named parameter.
for cur_group in iterable_or_groups:
- assert "named_params" in cur_group
- name_list = [x[0] for x in cur_group["named_params"]]
- p_list = [x[1] for x in cur_group["named_params"]]
- del cur_group["named_params"]
- cur_group["params"] = p_list
+ if "named_params" in cur_group:
+ name_list = [x[0] for x in cur_group["named_params"]]
+ p_list = [x[1] for x in cur_group["named_params"]]
+ del cur_group["named_params"]
+ cur_group["params"] = p_list
+ else:
+ assert "params" in cur_group
+ name_list = ["foo" for _ in cur_group["params"]]
param_groups.append(cur_group)
param_groups_names.append(name_list)
diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py
index c0f1e3087..29ac33c02 100644
--- a/egs/librispeech/ASR/zipformer/scaling.py
+++ b/egs/librispeech/ASR/zipformer/scaling.py
@@ -15,15 +15,16 @@
# limitations under the License.
-from typing import Optional, Tuple, Union
import logging
-import k2
-from torch.cuda.amp import custom_fwd, custom_bwd
-import random
-import torch
import math
+import random
+from typing import Optional, Tuple, Union
+
+import k2
+import torch
import torch.nn as nn
from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py
index 8087c1460..360523b8e 100755
--- a/egs/librispeech/ASR/zipformer/streaming_decode.py
+++ b/egs/librispeech/ASR/zipformer/streaming_decode.py
@@ -51,7 +51,7 @@ from streaming_beam_search import (
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py
index d16d87bac..b2f769d3f 100644
--- a/egs/librispeech/ASR/zipformer/subsampling.py
+++ b/egs/librispeech/ASR/zipformer/subsampling.py
@@ -16,11 +16,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple
import warnings
+from typing import Tuple
import torch
-from torch import Tensor, nn
from scaling import (
Balancer,
BiasNorm,
@@ -34,6 +33,7 @@ from scaling import (
SwooshR,
Whiten,
)
+from torch import Tensor, nn
class ConvNeXt(nn.Module):
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index 3ccf7d2f1..04caf2fd8 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -90,6 +90,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -965,7 +966,10 @@ def train_one_epoch(
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
- except: # noqa
+ except Exception as e:
+ logging.info(
+ f"Caught exception: {e}."
+ )
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
raise
@@ -1021,9 +1025,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 61ae378d8..17a3f8719 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -788,7 +788,7 @@ class Zipformer2EncoderLayer(nn.Module):
selected_attn_weights = attn_weights[0:1]
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
- elif not self.training and random.random() < float(self.const_attention_rate):
+ elif self.training and random.random() < float(self.const_attention_rate):
# Make attention weights constant. The intention is to
# encourage these modules to do something similar to an
# averaging-over-time operation.
diff --git a/egs/librispeech/ASR/zipformer_adapter/asr_datamodule.py b/egs/librispeech/ASR/zipformer_adapter/asr_datamodule.py
new file mode 120000
index 000000000..fa1b8cca3
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/asr_datamodule.py
@@ -0,0 +1 @@
+../tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/beam_search.py b/egs/librispeech/ASR/zipformer_adapter/beam_search.py
new file mode 120000
index 000000000..8554e44cc
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/decode.py b/egs/librispeech/ASR/zipformer_adapter/decode.py
new file mode 100755
index 000000000..91533be8d
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/decode.py
@@ -0,0 +1,1070 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+- To activate the adapter (test on the target domain)
+set --use-adapter True
+
+- To deactivate the adapter (test on the original domain)
+set --use-adapter False
+
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from train import add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals the
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # Note: padding the features at the end seems to cause insertions at the end of the utterance when used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
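+ # i.e. rescore with LM scales 0.10, 0.11, ..., 0.49 and report a WER
+ # for each scale.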
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+ elif params.avg == 1:
+ load_checkpoint(
+ f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
+ )
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ ),
+ strict=False,
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ ),
+ strict=False,
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ test_clean_cuts = librispeech.test_clean_cuts()
+ test_other_cuts = librispeech.test_other_cuts()
+
+ test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+ test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+ test_sets = ["test-clean", "test-other"]
+ test_dl = [test_clean_dl, test_other_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
new file mode 100755
index 000000000..bbc582f50
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
@@ -0,0 +1,1115 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from train import add_finetune_arguments, add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+conversational_filler = [
+ "UH",
+ "UHH",
+ "UM",
+ "EH",
+ "MM",
+ "HM",
+ "AH",
+ "HUH",
+ "HA",
+ "ER",
+ "OOF",
+ "HEE",
+ "ACH",
+ "EEE",
+ "EW",
+]
+unk_tags = ["<UNK>", "<unk>"]
+gigaspeech_punctuations = [
+    "<COMMA>",
+    "<PERIOD>",
+    "<QUESTIONMARK>",
+    "<EXCLAMATIONPOINT>",
+]
+gigaspeech_garbage_utterance_tags = ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"]
+non_scoring_words = (
+ conversational_filler
+ + unk_tags
+ + gigaspeech_punctuations
+ + gigaspeech_garbage_utterance_tags
+)
+
+
+def asr_text_post_processing(text: str) -> str:
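+    """Normalize a transcript for GigaSpeech-style scoring.
+
+    Illustrative example (hypothetical input):
+        "Uh I love e-commerce" -> "I LOVE E COMMERCE"
+    """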
+ # 1. convert to uppercase
+ text = text.upper()
+
+ # 2. remove hyphen
+ # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
+ text = text.replace("-", " ")
+
+ # 3. remove non-scoring words from evaluation
+ remaining_words = []
+ for word in text.split():
+ if word in non_scoring_words:
+ continue
+ remaining_words.append(word)
+
+ return " ".join(remaining_words)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def post_processing(
+ results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+ new_results = []
+ for key, ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((key, new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+           if greedy_search is used, it would be "greedy_search".
+           If beam search with a beam size of 7 is used, it would be
+           "beam_7".
+    - value: It contains the decoding result. `len(value)` equals the
+      batch size. `value[i]` is the decoding result for the i-th
+      utterance in the given batch.
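+
+      For example, decoding a batch of two utterances with greedy_search
+      might return (illustrative transcripts only):
+
+          {"greedy_search": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}
+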
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+      The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+      only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+      fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
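+        # search over LM scales 0.10, 0.11, ..., 0.49 during rescoring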
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
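+        # With the default fast_beam_search settings the key looks like
+        # "beam_20.0_max_contexts_8_max_states_64" (illustrative).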
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+      The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
+      only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+      fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = post_processing(results)
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / (params.decoding_method + "_giga")
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
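+    # With the default arguments (greedy_search, non-streaming) the suffix
+    # looks like (illustrative):
+    #   epoch-30-avg-15-context-2-max-sym-per-frame-1-use-averaged-model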
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
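+            # Each line of the context file is one biasing word/phrase
+            # (e.g. a rare proper noun); it is encoded into BPE token ids below.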
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+ gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts()
+
+ dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts)
+ test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dl = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer_adapter/decoder.py b/egs/librispeech/ASR/zipformer_adapter/decoder.py
new file mode 120000
index 000000000..cab465d2b
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/decoder.py
@@ -0,0 +1 @@
+../zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/encoder_interface.py b/egs/librispeech/ASR/zipformer_adapter/encoder_interface.py
new file mode 120000
index 000000000..aa5d0217a
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/encoder_interface.py
@@ -0,0 +1 @@
+../transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/export-onnx.py b/egs/librispeech/ASR/zipformer_adapter/export-onnx.py
new file mode 100755
index 000000000..062396168
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/export-onnx.py
@@ -0,0 +1,623 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
+
+"""
+This script exports a transducer model from PyTorch to ONNX.
+
+We use the pre-trained model from
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+2. Export the model to ONNX
+
+./zipformer_adapter/export-onnx.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --use-adapters 1 \
+ --adapter-dim 32 \
+ --exp-dir $repo/exp \
+ --num-encoder-layers "2,2,3,4,3,2" \
+ --downsampling-factor "1,2,4,8,4,2" \
+ --feedforward-dim "512,768,1024,1536,1024,768" \
+ --num-heads "4,4,4,8,4,4" \
+ --encoder-dim "192,256,384,512,384,256" \
+ --query-head-dim 32 \
+ --value-head-dim 12 \
+ --pos-head-dim 4 \
+ --pos-dim 48 \
+ --encoder-unmasked-dim "192,192,256,256,256,192" \
+ --cnn-module-kernel "31,31,15,15,15,31" \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --causal False \
+ --chunk-size "16,32,64,-1" \
+ --left-context-frames "64,128,256,-1"
+
+It will generate the following 3 files inside $repo/exp:
+
+ - encoder-epoch-99-avg-1.onnx
+ - decoder-epoch-99-avg-1.onnx
+ - joiner-epoch-99-avg-1.onnx
+
+See ./onnx_pretrained.py and ./onnx_check.py for how to
+use the exported ONNX models.
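+
+As a rough sanity check (illustrative only; ./onnx_pretrained.py is the
+supported way to run the exported models), the encoder can be loaded with
+onnxruntime:
+
+    import numpy as np
+    import onnxruntime as ort
+
+    sess = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
+    x = np.zeros((1, 100, 80), dtype=np.float32)  # (N, T, feature_dim=80)
+    x_lens = np.array([100], dtype=np.int64)
+    encoder_out, encoder_out_lens = sess.run(
+        ["encoder_out", "encoder_out_lens"], {"x": x, "x_lens": x_lens}
+    )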
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, Tuple
+
+import k2
+import onnx
+import torch
+import torch.nn as nn
+from decoder import Decoder
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_finetune_arguments, add_model_arguments, get_model, get_params
+from zipformer import Zipformer2
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import make_pad_mask, num_tokens, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer_adapter/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+ """Add meta data to an ONNX model. It is changed in-place.
+
+ Args:
+ filename:
+ Filename of the ONNX model to be changed.
+ meta_data:
+ Key-value pairs.
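+
+    A minimal usage sketch (hypothetical file name):
+
+        add_meta_data("encoder.onnx", {"model_type": "zipformer2"})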
+ """
+ model = onnx.load(filename)
+ for key, value in meta_data.items():
+ meta = model.metadata_props.add()
+ meta.key = key
+ meta.value = value
+
+ onnx.save(model, filename)
+
+
+class OnnxEncoder(nn.Module):
+ """A wrapper for Zipformer and the encoder_proj from the joiner"""
+
+ def __init__(
+ self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
+ ):
+ """
+ Args:
+ encoder:
+ A Zipformer encoder.
+ encoder_proj:
+ The projection layer for encoder from the joiner.
+ """
+ super().__init__()
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+ self.encoder_proj = encoder_proj
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Please see the help information of Zipformer.forward
+
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C)
+ x_lens:
+ A 1-D tensor of shape (N,). Its dtype is torch.int64
+ Returns:
+ Return a tuple containing:
+ - encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
+ - encoder_out_lens, A 1-D tensor of shape (N,)
+ """
+ x, x_lens = self.encoder_embed(x, x_lens)
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2)
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2)
+ encoder_out = self.encoder_proj(encoder_out)
+ # Now encoder_out is of shape (N, T, joiner_dim)
+
+ return encoder_out, encoder_out_lens
+
+
+class OnnxDecoder(nn.Module):
+ """A wrapper for Decoder and the decoder_proj from the joiner"""
+
+ def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
+ super().__init__()
+ self.decoder = decoder
+ self.decoder_proj = decoder_proj
+
+ def forward(self, y: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ y:
+ A 2-D tensor of shape (N, context_size).
+        Returns:
+ Return a 2-D tensor of shape (N, joiner_dim)
+ """
+ need_pad = False
+ decoder_output = self.decoder(y, need_pad=need_pad)
+ decoder_output = decoder_output.squeeze(1)
+ output = self.decoder_proj(decoder_output)
+
+ return output
+
+
+class OnnxJoiner(nn.Module):
+ """A wrapper for the joiner"""
+
+ def __init__(self, output_linear: nn.Linear):
+ super().__init__()
+ self.output_linear = output_linear
+
+ def forward(
+ self,
+ encoder_out: torch.Tensor,
+ decoder_out: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ encoder_out:
+ A 2-D tensor of shape (N, joiner_dim)
+ decoder_out:
+ A 2-D tensor of shape (N, joiner_dim)
+ Returns:
+ Return a 2-D tensor of shape (N, vocab_size)
+ """
+ logit = encoder_out + decoder_out
+ logit = self.output_linear(torch.tanh(logit))
+ return logit
+
+
+def export_encoder_model_onnx(
+ encoder_model: OnnxEncoder,
+ encoder_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the given encoder model to ONNX format.
+ The exported model has two inputs:
+
+ - x, a tensor of shape (N, T, C); dtype is torch.float32
+ - x_lens, a tensor of shape (N,); dtype is torch.int64
+
+ and it has two outputs:
+
+ - encoder_out, a tensor of shape (N, T', joiner_dim)
+ - encoder_out_lens, a tensor of shape (N,)
+
+ Args:
+ encoder_model:
+ The input encoder model
+ encoder_filename:
+ The filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ x = torch.zeros(1, 100, 80, dtype=torch.float32)
+ x_lens = torch.tensor([100], dtype=torch.int64)
+
+ encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
+
+ torch.onnx.export(
+ encoder_model,
+ (x, x_lens),
+ encoder_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["x", "x_lens"],
+ output_names=["encoder_out", "encoder_out_lens"],
+ dynamic_axes={
+ "x": {0: "N", 1: "T"},
+ "x_lens": {0: "N"},
+ "encoder_out": {0: "N", 1: "T"},
+ "encoder_out_lens": {0: "N"},
+ },
+ )
+
+ meta_data = {
+ "model_type": "zipformer2",
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "non-streaming zipformer2",
+ }
+ logging.info(f"meta_data: {meta_data}")
+
+ add_meta_data(filename=encoder_filename, meta_data=meta_data)
+
+
+def export_decoder_model_onnx(
+ decoder_model: OnnxDecoder,
+ decoder_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the decoder model to ONNX format.
+
+ The exported model has one input:
+
+ - y: a torch.int64 tensor of shape (N, decoder_model.context_size)
+
+ and has one output:
+
+ - decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
+
+ Args:
+ decoder_model:
+ The decoder model to be exported.
+ decoder_filename:
+ Filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ context_size = decoder_model.decoder.context_size
+ vocab_size = decoder_model.decoder.vocab_size
+
+ y = torch.zeros(10, context_size, dtype=torch.int64)
+ decoder_model = torch.jit.script(decoder_model)
+ torch.onnx.export(
+ decoder_model,
+ y,
+ decoder_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["y"],
+ output_names=["decoder_out"],
+ dynamic_axes={
+ "y": {0: "N"},
+ "decoder_out": {0: "N"},
+ },
+ )
+
+ meta_data = {
+ "context_size": str(context_size),
+ "vocab_size": str(vocab_size),
+ }
+ add_meta_data(filename=decoder_filename, meta_data=meta_data)
+
+
+def export_joiner_model_onnx(
+ joiner_model: nn.Module,
+ joiner_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the joiner model to ONNX format.
+ The exported joiner model has two inputs:
+
+ - encoder_out: a tensor of shape (N, joiner_dim)
+ - decoder_out: a tensor of shape (N, joiner_dim)
+
+ and produces one output:
+
+ - logit: a tensor of shape (N, vocab_size)
+ """
+ joiner_dim = joiner_model.output_linear.weight.shape[1]
+ logging.info(f"joiner dim: {joiner_dim}")
+
+ projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+ projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+
+ torch.onnx.export(
+ joiner_model,
+ (projected_encoder_out, projected_decoder_out),
+ joiner_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=[
+ "encoder_out",
+ "decoder_out",
+ ],
+ output_names=["logit"],
+ dynamic_axes={
+ "encoder_out": {0: "N"},
+ "decoder_out": {0: "N"},
+ "logit": {0: "N"},
+ },
+ )
+ meta_data = {
+ "joiner_dim": str(joiner_dim),
+ }
+ add_meta_data(filename=joiner_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ model.to(device)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
+
+ encoder = OnnxEncoder(
+ encoder=model.encoder,
+ encoder_embed=model.encoder_embed,
+ encoder_proj=model.joiner.encoder_proj,
+ )
+
+ decoder = OnnxDecoder(
+ decoder=model.decoder,
+ decoder_proj=model.joiner.decoder_proj,
+ )
+
+ joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
+
+ encoder_num_param = sum([p.numel() for p in encoder.parameters()])
+ decoder_num_param = sum([p.numel() for p in decoder.parameters()])
+ joiner_num_param = sum([p.numel() for p in joiner.parameters()])
+ total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
+ logging.info(f"encoder parameters: {encoder_num_param}")
+ logging.info(f"decoder parameters: {decoder_num_param}")
+ logging.info(f"joiner parameters: {joiner_num_param}")
+ logging.info(f"total parameters: {total_num_param}")
+
+ if params.iter > 0:
+ suffix = f"iter-{params.iter}"
+ else:
+ suffix = f"epoch-{params.epoch}"
+
+ suffix += f"-avg-{params.avg}"
+
+ opset_version = 13
+
+ logging.info("Exporting encoder")
+ encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
+ export_encoder_model_onnx(
+ encoder,
+ encoder_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported encoder to {encoder_filename}")
+
+ logging.info("Exporting decoder")
+ decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
+ export_decoder_model_onnx(
+ decoder,
+ decoder_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported decoder to {decoder_filename}")
+
+ logging.info("Exporting joiner")
+ joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
+ export_joiner_model_onnx(
+ joiner,
+ joiner_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported joiner to {joiner_filename}")
+
+ # Generate int8 quantization models
+ # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+ logging.info("Generate int8 quantization models")
+
+ encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=encoder_filename,
+ model_output=encoder_filename_int8,
+ op_types_to_quantize=["MatMul"],
+ weight_type=QuantType.QInt8,
+ )
+
+ decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=decoder_filename,
+ model_output=decoder_filename_int8,
+ op_types_to_quantize=["MatMul", "Gather"],
+ weight_type=QuantType.QInt8,
+ )
+
+ joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=joiner_filename,
+ model_output=joiner_filename_int8,
+ op_types_to_quantize=["MatMul"],
+ weight_type=QuantType.QInt8,
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/librispeech/ASR/zipformer_adapter/export.py b/egs/librispeech/ASR/zipformer_adapter/export.py
new file mode 100755
index 000000000..72dfc081b
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/export.py
@@ -0,0 +1,520 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+Note: This is an example for the LibriSpeech dataset. If you are using a
+different dataset, you should change the argument values accordingly.
+
+(1) Export to torchscript model using torch.jit.script()
+
+- For non-streaming model:
+
+./zipformer_adapter/export.py \
+ --exp-dir ./zipformer_adapter/exp \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --use-adapters 1 \
+ --adapter-dim 16 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("jit_script.pt")`.
+
+Check ./jit_pretrained.py for its usage.
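+
+A minimal loading sketch (illustrative only; see ./jit_pretrained.py for the
+full decoding pipeline):
+
+    import torch
+
+    model = torch.jit.load("jit_script.pt")
+    model.eval()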
+
+Check https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+- For streaming model:
+
+./zipformer_adapter/export.py \
+ --exp-dir ./zipformer_adapter/exp \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --use-adapters 1 \
+ --adapter-dim 16 \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
+You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
+
+Check ./jit_pretrained_streaming.py for its usage.
+
+Check https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+- For non-streaming model:
+
+./zipformer_adapter/export.py \
+ --exp-dir ./zipformer_adapter/exp \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --use-adapters 1 \
+ --adapter-dim 16 \
+ --avg 9
+
+- For streaming model:
+
+./zipformer_adapter/export.py \
+ --exp-dir ./zipformer_adapter/exp \
+ --causal 1 \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --use-adapters 1 \
+ --adapter-dim 16 \
+ --epoch 30 \
+ --avg 9
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+- For non-streaming model:
+
+To use the generated file with `zipformer_adapter/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/librispeech/ASR
+ ./zipformer_adapter/decode_gigaspeech.py \
+ --exp-dir ./zipformer_adapter/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --use-adapters 1 \
+ --adapter-dim 16 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bpe_500/bpe.model
+
+- For streaming model:
+
+To use the generated file with `zipformer_adapter/decode.py` and `zipformer_adapter/streaming_decode.py`, you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/librispeech/ASR
+
+ # simulated streaming decoding
+ ./zipformer_adapter/decode_gigaspeech.py \
+ --exp-dir ./zipformer_adapter/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bpe_500/bpe.model
+
+ # chunk-wise streaming decoding
+ ./zipformer_adapter/streaming_decode.py \
+ --exp-dir ./zipformer_adapter/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from torch import Tensor, nn
+from train import add_finetune_arguments, add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import make_pad_mask, num_tokens, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer_adapter/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ It will generate a file named jit_script.pt.
+ Check ./jit_pretrained.py for how to use it.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+class EncoderModel(nn.Module):
+ """A wrapper for encoder and encoder_embed"""
+
+ def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+ super().__init__()
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+
+ def forward(
+ self, features: Tensor, feature_lengths: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ features: (N, T, C)
+ feature_lengths: (N,)
+ """
+ x, x_lens = self.encoder_embed(features, feature_lengths)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ return encoder_out, encoder_out_lens
+
+
+class StreamingEncoderModel(nn.Module):
+ """A wrapper for encoder and encoder_embed"""
+
+ def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+ super().__init__()
+ assert len(encoder.chunk_size) == 1, encoder.chunk_size
+ assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
+ self.chunk_size = encoder.chunk_size[0]
+ self.left_context_len = encoder.left_context_frames[0]
+
+        # The encoder_embed subsamples features: T -> (T - 7) // 2
+ # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+ self.pad_length = 7 + 2 * 3
+
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+
+ def forward(
+ self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
+ ) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """Streaming forward for encoder_embed and encoder.
+
+ Args:
+ features: (N, T, C)
+ feature_lengths: (N,)
+ states: a list of Tensors
+
+ Returns encoder outputs, output lengths, and updated states.
+ """
+ chunk_size = self.chunk_size
+ left_context_len = self.left_context_len
+
+ cached_embed_left_pad = states[-2]
+ x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
+ x=features,
+ x_lens=feature_lengths,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+
+ # processed_mask is used to mask out initial states
+ processed_mask = torch.arange(left_context_len, device=x.device).expand(
+ x.size(0), left_context_len
+ )
+ processed_lens = states[-1] # (batch,)
+ # (batch, left_context_size)
+ processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ encoder_states = states[:-2]
+
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = self.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+ return encoder_out, encoder_out_lens, new_states
+
+ @torch.jit.export
+ def get_init_states(
+ self,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+ ) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
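+        In total, that is 6 entries per encoder layer plus the two extra
+        entries described above.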
+ """
+ states = self.encoder.get_init_states(batch_size, device)
+
+ embed_states = self.encoder_embed.get_init_states(batch_size, device)
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ states.append(processed_lens)
+
+ return states
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+
+ logging.info(f"device: {device}")
+
+ token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
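+ # Note on the "(excluded)" wording above: as far as we understand,
+ # average_checkpoints_with_averaged_model() works on the cumulative
+ # running averages stored in the checkpoints, recovering the mean over
+ # the interval (start, epoch] from the difference of the two averaged
+ # models, so the checkpoint at `start` itself does not contribute.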
+
+ model.eval()
+
+ if params.jit is True:
+ convert_scaled_to_non_scaled(model, inplace=True)
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+ # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+
+ # Wrap encoder and encoder_embed as a module
+ if params.causal:
+ model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
+ chunk_size = model.encoder.chunk_size
+ left_context_len = model.encoder.left_context_len
+ filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
+ else:
+ model.encoder = EncoderModel(model.encoder, model.encoder_embed)
+ filename = "jit_script.pt"
+
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ model.save(str(params.exp_dir / filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torchscript. Export model.state_dict()")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/librispeech/ASR/zipformer_adapter/joiner.py b/egs/librispeech/ASR/zipformer_adapter/joiner.py
new file mode 120000
index 000000000..444cb5f15
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/joiner.py
@@ -0,0 +1 @@
+../zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/model.py b/egs/librispeech/ASR/zipformer_adapter/model.py
new file mode 120000
index 000000000..0c6fe6112
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/model.py
@@ -0,0 +1 @@
+../zipformer/model.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py b/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py
new file mode 100755
index 000000000..e3f7ce85a
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/onnx_decode.py
@@ -0,0 +1,386 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads ONNX exported models and uses them to decode the test sets.
+
+We use the pre-trained model from
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+2. Export the model to ONNX
+
+./zipformer/export-onnx.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --causal False
+
+It will generate the following 3 files inside $repo/exp:
+
+ - encoder-epoch-99-avg-1.onnx
+ - decoder-epoch-99-avg-1.onnx
+ - joiner-epoch-99-avg-1.onnx
+
+3. Run this file
+
+./zipformer/onnx_decode.py \
+ --exp-dir $repo/exp \
+ --max-duration 600 \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+"""
+
+
+import argparse
+import logging
+import time
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from k2 import SymbolTable
+from onnx_pretrained import OnnxModel, greedy_search
+
+from icefall.utils import setup_logger, store_transcripts, write_error_stats
+
+conversational_filler = [
+ "UH",
+ "UHH",
+ "UM",
+ "EH",
+ "MM",
+ "HM",
+ "AH",
+ "HUH",
+ "HA",
+ "ER",
+ "OOF",
+ "HEE",
+ "ACH",
+ "EEE",
+ "EW",
+]
+unk_tags = ["<UNK>", "<unk>"]
+gigaspeech_punctuations = [
+ "<COMMA>",
+ "<PERIOD>",
+ "<QUESTIONMARK>",
+ "<EXCLAMATIONPOINT>",
+]
+gigaspeech_garbage_utterance_tags = ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"]
+non_scoring_words = (
+ conversational_filler
+ + unk_tags
+ + gigaspeech_punctuations
+ + gigaspeech_garbage_utterance_tags
+)
+
+
+def asr_text_post_processing(text: str) -> str:
+ # 1. convert to uppercase
+ text = text.upper()
+
+ # 2. remove hyphen
+ # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
+ text = text.replace("-", " ")
+
+ # 3. remove non-scoring words from evaluation
+ remaining_words = []
+ for word in text.split():
+ if word in non_scoring_words:
+ continue
+ remaining_words.append(word)
+
+ return " ".join(remaining_words)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--encoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the encoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--decoder-model-filename",
+ type=str,
+ required=True,
+ help="Path to the decoder onnx model. ",
+ )
+
+ parser.add_argument(
+ "--joiner-model-filename",
+ type=str,
+ required=True,
+ help="Path to the joiner onnx model. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ help="""Path to tokens.txt.""",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="Valid values are greedy_search and modified_beam_search",
+ )
+
+ return parser
+
+
+def post_processing(
+ results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+ new_results = []
+ for key, ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((key, new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ model: OnnxModel, token_table: SymbolTable, batch: dict
+) -> List[List[str]]:
+ """Decode one batch and return the result.
+ Currently only greedy_search is supported.
+
+ Args:
+ model:
+ The neural model.
+ token_table:
+ The token table.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+
+ Returns:
+ Return the decoded results for each utterance.
+ """
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
+
+ encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
+
+ hyps = greedy_search(
+ model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
+ )
+
+ def token_ids_to_words(token_ids: List[int]) -> str:
+ text = ""
+ for i in token_ids:
+ text += token_table[i]
+ return text.replace("▁", " ").strip()
+
+ hyps = [token_ids_to_words(h).split() for h in hyps]
+ return hyps
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ model: nn.Module,
+ token_table: SymbolTable,
+) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ model:
+ The neural model.
+ token_table:
+ The token table.
+
+ Returns:
+ - A list of tuples. Each tuple contains three elements:
+ - cut_id,
+ - reference transcript,
+ - predicted result.
+ - The total duration (in seconds) of the dataset.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ log_interval = 10
+ total_duration = 0
+
+ results = []
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+ total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
+
+ hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
+
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results.extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+ return results, total_duration
+
+
+def save_results(
+ res_dir: Path,
+ test_set_name: str,
+ results: List[Tuple[str, List[str], List[str]]],
+):
+ recog_path = res_dir / f"recogs-{test_set_name}.txt"
+ results = post_processing(results)
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = res_dir / f"errs-{test_set_name}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
+ with open(errs_info, "w") as f:
+ print("WER", file=f)
+ print(wer, file=f)
+
+ s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+
+ assert (
+ args.decoding_method == "greedy_search"
+ ), "Only supports greedy_search currently."
+ res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
+
+ setup_logger(f"{res_dir}/log-decode")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ logging.info(f"Device: {device}")
+
+ token_table = SymbolTable.from_file(args.tokens)
+
+ logging.info(vars(args))
+
+ logging.info("About to create model")
+ model = OnnxModel(
+ encoder_model_filename=args.encoder_model_filename,
+ decoder_model_filename=args.decoder_model_filename,
+ joiner_model_filename=args.joiner_model_filename,
+ )
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+ gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts()
+
+ dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts)
+ test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dl = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ start_time = time.time()
+ results, total_duration = decode_dataset(
+ dl=test_dl, model=model, token_table=token_table
+ )
+ end_time = time.time()
+ elapsed_seconds = end_time - start_time
+ rtf = elapsed_seconds / total_duration
+
+ logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
+ logging.info(f"Wave duration: {total_duration:.3f} s")
+ logging.info(
+ f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
+ )
+
+ save_results(res_dir=res_dir, test_set_name=test_set, results=results)
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer_adapter/onnx_pretrained.py b/egs/librispeech/ASR/zipformer_adapter/onnx_pretrained.py
new file mode 120000
index 000000000..a085def83
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/onnx_pretrained.py
@@ -0,0 +1 @@
+../zipformer/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/optim.py b/egs/librispeech/ASR/zipformer_adapter/optim.py
new file mode 120000
index 000000000..207eecfcd
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/optim.py
@@ -0,0 +1 @@
+../zipformer/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/scaling.py b/egs/librispeech/ASR/zipformer_adapter/scaling.py
new file mode 120000
index 000000000..58e4b0a0f
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/scaling.py
@@ -0,0 +1 @@
+../zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/scaling_converter.py b/egs/librispeech/ASR/zipformer_adapter/scaling_converter.py
new file mode 120000
index 000000000..bc7c7b5e3
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/scaling_converter.py
@@ -0,0 +1 @@
+../zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/subsampling.py b/egs/librispeech/ASR/zipformer_adapter/subsampling.py
new file mode 120000
index 000000000..d178adc2e
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/subsampling.py
@@ -0,0 +1 @@
+../zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py
new file mode 100755
index 000000000..6c55896a8
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/train.py
@@ -0,0 +1,1544 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# Finetune non-streaming model using adapters:
+
+./zipformer_adapter/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --do-finetune 1 \
+ --use-mux 0 \
+ --use-adapters 1 \
+ --adapter-dim 16 \
+ --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut, CutSet
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ ) + 1000000
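+ # For example, with --max-duration 1000, --world-size 4 and the default
+ # --ref-duration 600, every real batch counts as 1000 * 4 / 600 ≈ 6.7
+ # reference batches. The large constant offset (1000000) presumably puts
+ # the fine-tuned model straight into the fully warmed-up regime of the
+ # batch-count-dependent schedules (e.g. ScheduledFloat dropout).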
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--do-finetune",
+ type=str2bool,
+ default=True,
+ help="If true, finetune from a pre-trained checkpoint",
+ )
+
+ parser.add_argument(
+ "--use-mux",
+ type=str2bool,
+ default=False,
+ help="""
+ Whether to mux the original data with the fine-tuning data. If true,
+ the new data is mixed with the original LibriSpeech data during
+ fine-tuning, which is useful if you want to maintain performance on
+ the original domain.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-adapters",
+ type=str2bool,
+ default=True,
+ help="If use adapter to finetune the model",
+ )
+
+ parser.add_argument(
+ "--adapter-dim",
+ type=int,
+ default=16,
+ help="The bottleneck dimension of the adapter",
+ )
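+
+ # For orientation only: an adapter in this recipe is a small residual
+ # bottleneck inserted into the encoder layers, conceptually something
+ # like the sketch below (the real module lives in zipformer.py and may
+ # differ in detail):
+ #
+ # class Adapter(nn.Module):
+ # def __init__(self, dim: int, adapter_dim: int):
+ # super().__init__()
+ # self.down = nn.Linear(dim, adapter_dim)
+ # self.up = nn.Linear(adapter_dim, dim)
+ #
+ # def forward(self, x):
+ # return x + self.up(torch.relu(self.down(x)))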
+
+ parser.add_argument(
+ "--init-modules",
+ type=str,
+ default=None,
+ help="""
+ Modules to be initialised from the checkpoint. It matches all
+ parameters starting with a specific prefix; multiple prefixes are
+ given as a comma-separated list. If None, all modules are initialised.
+ For example, to initialise only the parameters starting with "encoder",
+ use "encoder"; to initialise the parameters starting with "encoder"
+ or "joiner", use "encoder,joiner".
+ """,
+ )
+
+ parser.add_argument(
+ "--finetune-ckpt",
+ type=str,
+ default=None,
+ help="Fine-tuning from which checkpoint (path to a .pt file)",
+ )
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used to write statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
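+ # e.g. _to_int_tuple("2,2,3,4,3,2") -> (2, 2, 3, 4, 3, 2)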
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
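+ # For example, 10 seconds of 100-frame-per-second fbank features
+ # (T = 1000) become (1000 - 7) // 2 = 496 frames here (about 50 Hz),
+ # and roughly half that (about 25 Hz) after the extra output
+ # downsampling by 2 in the encoder.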
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ use_adapters=params.use_adapters,
+ adapter_dim=params.adapter_dim,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def load_model_params(
+ ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
+):
+ """Load model params from checkpoint
+
+ Args:
+ ckpt (str): Path to the checkpoint
+ model (nn.Module): model to be loaded
+ init_modules (list[str]): List of modules to be initialized
+
+ """
+ logging.info(f"Loading checkpoint from {ckpt}")
+ checkpoint = torch.load(ckpt, map_location="cpu")
+
+ # if module list is empty, load the whole model from ckpt
+ if not init_modules:
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ model.load_state_dict(dst_state_dict, strict=strict)
+ else:
+ model.load_state_dict(checkpoint["model"], strict=strict)
+ else:
+ src_state_dict = checkpoint["model"]
+ dst_state_dict = model.state_dict()
+ for module in init_modules:
+ logging.info(f"Loading parameters starting with prefix {module}")
+ src_keys = [
+ k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ dst_keys = [
+ k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ assert set(src_keys) == set(dst_keys) # two sets should match exactly
+ for key in src_keys:
+ dst_state_dict[key] = src_state_dict.pop(key)
+
+ model.load_state_dict(dst_state_dict, strict=strict)
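+ # E.g. with init_modules=["encoder"], only parameters whose names start
+ # with "encoder." are copied from the checkpoint; all remaining modules
+ # keep their freshly initialised values, which is why the caller
+ # typically passes strict=False.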
+
+ return None
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mix precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
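+ # For example, at batch_idx_train = 0 the combination is
+ # 1.0 * simple_loss + 0.1 * pruned_loss; after warm_step batches the
+ # weights settle at simple_loss_scale (0.5 by default) and 1.0.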
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dls: torch.utils.data.DataLoader,
+ valid_sets: List[str],
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dls:
+ Dataloaders for the validation datasets.
+ valid_sets:
+ Names of the validation sets, parallel to `valid_dls`.
+ scaler:
+ The scaler used for mix precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ # set modules except adapters to eval mode
+ for name, m in model.named_modules():
+ if "adapter" in name:
+ m.training = True
+ else:
+ m.training = False
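+ # Note: model.train() above first marks every module as training; this
+ # loop then flips the flag module by module so that only the adapter
+ # modules keep training-time behaviour (e.g. dropout), while the frozen
+ # backbone behaves as in eval mode.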
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction=sum, so the loss is summed over the utterances
+ # in the batch; no normalization has been applied to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ for valid_set, valid_dl in zip(valid_sets, valid_dls):
+ logging.info(f"Computing validation loss on {valid_set}")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ logging.info(
+ f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
+ )
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
+ )
+ model.train()
+ # set modules except adapters to eval mode
+ for name, m in model.named_modules():
+ if "adapter" in name:
+ m.training = True
+ else:
+ m.training = False
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ # load model parameters for model fine-tuning
+ if params.do_finetune:
+ assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
+ modules = params.init_modules.split(",") if params.init_modules else None
+ checkpoints = load_model_params(
+ ckpt=params.finetune_ckpt, model=model, init_modules=modules, strict=False
+ )
+ # Need to update the model_avg if we use initialisation
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+ else:
+ # resuming training
+ assert params.start_epoch > 1, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ # keep the original model untouched, only update the adapters
+ num_trainable = 0
+ for name, p in model.named_parameters():
+ if "adapter" in name:
+ p.requires_grad = True
+ num_trainable += p.numel()
+ else:
+ p.requires_grad = False
+
+ logging.info(
+ "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
+ num_trainable, num_trainable / num_param * 100
+ )
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts()
+ if params.use_mux:
+ librispeech_cuts = librispeech.train_all_shuf_cuts()
+ train_cuts = CutSet.mux(
+ gigaspeech_cuts, # num cuts = 688182
+ librispeech_cuts, # num cuts = 843723
+ weights=[688182, 843723],
+ stop_early=True,
+ )
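+ # With these weights the mux draws cuts roughly in proportion to corpus
+ # size, i.e. about 45% GigaSpeech-S and 55% LibriSpeech on average, and
+ # stops as soon as the smaller source is exhausted (stop_early=True).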
+ else:
+ train_cuts = gigaspeech_cuts
+ logging.info(train_cuts)
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
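+ # For example, a 2-second cut at 100 frames per second has
+ # c.num_frames = 200, giving T = ((200 - 7) // 2 + 1) // 2 = 48, so it
+ # is kept only if its BPE token sequence has at most 48 tokens.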
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+
+ valid_sets = ["librispeech", "gigaspeech"]
+ valid_dls = [
+ librispeech.valid_dataloaders(valid_cuts),
+ librispeech.valid_dataloaders(gigaspeech_dev_cuts),
+ ]
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dls=valid_dls,
+ valid_sets=valid_sets,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
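+# Illustrative usage (an assumption for documentation purposes; the actual flag
+# names are defined by get_parser() and LibriSpeechAsrDataModule.add_arguments()
+# above):
+#
+#   python ./train.py --world-size 4 --exp-dir zipformer_adapter/exp
+#
+# With --world-size > 1, main() spawns one training process per GPU via
+# mp.spawn(); otherwise run() is executed directly in the current process.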
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py
new file mode 100644
index 000000000..8e2dfdd72
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py
@@ -0,0 +1,2527 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey,
+# Zengwei Yao,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import math
+import random
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+from encoder_interface import EncoderInterface
+from scaling import (
+ Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+ ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+ ActivationDropoutAndLinear,
+ Balancer,
+ BiasNorm,
+ ChunkCausalDepthwiseConv1d,
+ Dropout2,
+ FloatLike,
+ ScheduledFloat,
+ SwooshL,
+ SwooshR,
+ Whiten,
+ convert_num_channels,
+ limit_param_value,
+ penalize_abs_values_gt,
+ softmax,
+)
+from torch import Tensor, nn
+
+
+class Zipformer2(EncoderInterface):
+ """
+ Args:
+
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
+ as downsampling_factor if they are single ints or one-element tuples. The length of
+ downsampling_factor defines the number of stacks.
+
+ output_downsampling_factor (int): how much to downsample at the output. Note:
+ we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
+ You should probably leave this at 2.
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
+ Note: this is in addition to the downsampling factor of 2 that is applied in
+ the frontend (self.encoder_embed).
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
+ encoder stack.
+       num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
+ encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
+ the encoder stacks for purposes of per-frame dropout (recommend 256 for
+ now).
+       query_head_dim (int or Tuple[int]): dimension of the query and key per attention
+          head: per stack, if a tuple.
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
+ attention head
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
+ Must be at least 4.
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
+       cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module
+
+ pos_dim (int): the dimension of each positional-encoding vector prior to projection,
+ e.g. 128.
+
+ dropout (float): dropout rate
+ warmup_batches (float): number of batches to warm up over; this controls
+ dropout of encoder layers.
+ causal (bool): if True, support chunkwise causal convolution. This should
+ not hurt WER as no modeling power is lost, but the convolution modules will be
+ slightly slower and use more memory. Enables use of the chunk_size and
+ left_context_chunks options in forward(), which simulates streaming
+ decoding.
+ chunk_size: (list of int): only set this to other than [-1] if causal;
+ the chunk size will be randomly chosen from this list. -1 means no chunking.
+ left_context_frames: (list of int): determines the number of left-
+ context chunks for causal training; will be rounded to a number of
+ chunks. Must not be less than cnn_module_kernel (after factoring in
+ rounding and downsampling); an error will be thrown if this is violated.
+ use_adapters: insert adapters in the zipformer encoder
+ adapter_dim: the dimension of the adapters
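+
+    Example (illustrative only -- the shapes and hyper-parameters below are
+    small made-up values for demonstration, not recommended settings)::
+
+        >>> encoder = Zipformer2(
+        ...     downsampling_factor=(1, 2),
+        ...     encoder_dim=(192, 256),
+        ...     num_encoder_layers=(2, 2),
+        ...     encoder_unmasked_dim=(128, 192),
+        ...     num_heads=(4, 4),
+        ...     use_adapters=True,
+        ...     adapter_dim=16,
+        ... )
+        >>> x = torch.rand(100, 8, 192)         # (seq_len, batch_size, encoder_dim[0])
+        >>> x_lens = torch.full((8,), 100)
+        >>> out, out_lens = encoder(x, x_lens)  # out: (50, 8, 256); out_lens: all 50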
+ """
+
+ def __init__(
+ self,
+ output_downsampling_factor: int = 2,
+ downsampling_factor: Tuple[int] = (2, 4),
+ encoder_dim: Union[int, Tuple[int]] = 384,
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
+ encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
+ query_head_dim: Union[int, Tuple[int]] = 24,
+ pos_head_dim: Union[int, Tuple[int]] = 4,
+ value_head_dim: Union[int, Tuple[int]] = 12,
+ num_heads: Union[int, Tuple[int]] = 8,
+ feedforward_dim: Union[int, Tuple[int]] = 1536,
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
+ pos_dim: int = 192,
+ dropout: FloatLike = None, # see code below for default
+ warmup_batches: float = 4000.0,
+ causal: bool = False,
+ chunk_size: Tuple[int] = [-1],
+ left_context_frames: Tuple[int] = [-1],
+ use_adapters: bool = False,
+ adapter_dim: int = 16,
+ ) -> None:
+ super(Zipformer2, self).__init__()
+
+ if dropout is None:
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+
+ def _to_tuple(x):
+ """Converts a single int or a 1-tuple of an int to a tuple with the same length
+ as downsampling_factor"""
+ if isinstance(x, int):
+ x = (x,)
+ if len(x) == 1:
+ x = x * len(downsampling_factor)
+ else:
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
+ return x
+
+ self.output_downsampling_factor = output_downsampling_factor # int
+ self.downsampling_factor = downsampling_factor # tuple
+ self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
+ self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
+ encoder_unmasked_dim
+ ) # tuple
+ num_encoder_layers = _to_tuple(num_encoder_layers)
+ self.num_encoder_layers = num_encoder_layers
+ self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
+ self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
+ pos_head_dim = _to_tuple(pos_head_dim)
+ self.num_heads = num_heads = _to_tuple(num_heads)
+ feedforward_dim = _to_tuple(feedforward_dim)
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
+
+ self.causal = causal
+ self.chunk_size = chunk_size
+ self.left_context_frames = left_context_frames
+ self.use_adapters = use_adapters
+
+ for u, d in zip(encoder_unmasked_dim, encoder_dim):
+ assert u <= d
+
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
+ encoders = []
+
+ num_encoders = len(downsampling_factor)
+ for i in range(num_encoders):
+ encoder_layer = Zipformer2EncoderLayer(
+ embed_dim=encoder_dim[i],
+ pos_dim=pos_dim,
+ num_heads=num_heads[i],
+ query_head_dim=query_head_dim[i],
+ pos_head_dim=pos_head_dim[i],
+ value_head_dim=value_head_dim[i],
+ feedforward_dim=feedforward_dim[i],
+ dropout=dropout,
+ cnn_module_kernel=cnn_module_kernel[i],
+ causal=causal,
+ use_adapters=use_adapters,
+ adapter_dim=adapter_dim,
+ )
+
+ # For the segment of the warmup period, we let the Conv2dSubsampling
+ # layer learn something. Then we start to warm up the other encoders.
+ encoder = Zipformer2Encoder(
+ encoder_layer,
+ num_encoder_layers[i],
+ pos_dim=pos_dim,
+ dropout=dropout,
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+ )
+
+ if downsampling_factor[i] != 1:
+ encoder = DownsampledZipformer2Encoder(
+ encoder,
+ dim=encoder_dim[i],
+ downsample=downsampling_factor[i],
+ dropout=dropout,
+ )
+
+ encoders.append(encoder)
+
+ self.encoders = nn.ModuleList(encoders)
+
+ self.downsample_output = SimpleDownsample(
+ max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
+ )
+
+ def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
+ """
+ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
+ randomized feature masks, one per encoder.
+        On e.g. 15% of frames, these masks will zero out all encoder dims larger than
+        some supplied number, e.g. >256, so in effect on those frames we are using
+        a smaller encoder dim.
+
+ We generate the random masks at this level because we want the 2 masks to 'agree'
+ all the way up the encoder stack. This will mean that the 1st mask will have
+ mask values repeated self.zipformer_subsampling_factor times.
+
+ Args:
+          x: the embeddings (needed for the shape and dtype and device), of shape
+             (num_frames, batch_size, encoder_dim[0]).
+ """
+ num_encoders = len(self.encoder_dim)
+ if not self.training:
+ return [1.0] * num_encoders
+
+ (num_frames0, batch_size, _encoder_dims0) = x.shape
+
+ assert self.encoder_dim[0] == _encoder_dims0, (
+ self.encoder_dim[0],
+ _encoder_dims0,
+ )
+
+ feature_mask_dropout_prob = 0.125
+
+ # mask1 shape: (1, batch_size, 1)
+ mask1 = (
+ torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
+ ).to(x.dtype)
+
+ # mask2 has additional sequences masked, about twice the number.
+ mask2 = torch.logical_and(
+ mask1,
+ (
+ torch.rand(1, batch_size, 1, device=x.device)
+ > feature_mask_dropout_prob
+ ).to(x.dtype),
+ )
+
+ # dim: (1, batch_size, 2)
+ mask = torch.cat((mask1, mask2), dim=-1)
+
+ feature_masks = []
+ for i in range(num_encoders):
+ channels = self.encoder_dim[i]
+ feature_mask = torch.ones(
+ 1, batch_size, channels, dtype=x.dtype, device=x.device
+ )
+ u1 = self.encoder_unmasked_dim[i]
+ u2 = u1 + (channels - u1) // 2
+
+ feature_mask[:, :, u1:u2] *= mask[..., 0:1]
+ feature_mask[:, :, u2:] *= mask[..., 1:2]
+
+ feature_masks.append(feature_mask)
+
+ return feature_masks
+
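+    # For example, with chunk_size=[32] and left_context_frames=[128] (illustrative
+    # values), get_chunk_info() below returns (32, 4): 32-frame chunks with
+    # 4 chunks (128 frames) of left context.  With causal=False it always
+    # returns (-1, -1), i.e. no chunking.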
+ def get_chunk_info(self) -> Tuple[int, int]:
+ """
+ Returns chunk_size and left_context_chunks.
+ """
+ if not self.causal:
+ return -1, -1
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ assert len(self.chunk_size) == 1, self.chunk_size
+ chunk_size = self.chunk_size[0]
+ else:
+ chunk_size = random.choice(self.chunk_size)
+
+ if chunk_size == -1:
+ left_context_chunks = -1
+ else:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ assert len(self.left_context_frames) == 1, self.left_context_frames
+ left_context_frames = self.left_context_frames[0]
+ else:
+ left_context_frames = random.choice(self.left_context_frames)
+ # Note: in Python, -1 // n == -1 for n > 0
+ left_context_chunks = left_context_frames // chunk_size
+ if left_context_chunks == 0:
+ left_context_chunks = 1
+
+ return chunk_size, left_context_chunks
+
+ def forward(
+ self,
+ x: Tensor,
+ x_lens: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (seq_len, batch_size, feature_dim).
+ x_lens:
+ A tensor of shape (batch_size,) containing the number of frames in
+ `x` before padding.
+ src_key_padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ Returns:
+ Return a tuple containing 2 tensors:
+ - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
+ - lengths, a tensor of shape (batch_size,) containing the number
+ of frames in `embeddings` before padding.
+ """
+ outputs = []
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ feature_masks = [1.0] * len(self.encoder_dim)
+ else:
+ feature_masks = self.get_feature_masks(x)
+
+ chunk_size, left_context_chunks = self.get_chunk_info()
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+            # Simulated streaming decoding is not supported when exporting the model
+ attn_mask = None
+ else:
+ attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
+
+ for i, module in enumerate(self.encoders):
+ ds = self.downsampling_factor[i]
+ x = convert_num_channels(x, self.encoder_dim[i])
+
+ x = module(
+ x,
+ chunk_size=chunk_size,
+ feature_mask=feature_masks[i],
+ src_key_padding_mask=(
+ None
+ if src_key_padding_mask is None
+ else src_key_padding_mask[..., ::ds]
+ ),
+ attn_mask=attn_mask,
+ )
+ outputs.append(x)
+
+ # if the last output has the largest dimension, x will be unchanged,
+ # it will be the same as outputs[-1]. Otherwise it will be concatenated
+ # from different pieces of 'outputs', taking each dimension from the
+ # most recent output that has it present.
+ x = self._get_full_dim_output(outputs)
+ x = self.downsample_output(x)
+        # SimpleDownsample (self.downsample_output) has this rounding behavior.
+ assert self.output_downsampling_factor == 2, self.output_downsampling_factor
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ lengths = (x_lens + 1) // 2
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ lengths = (x_lens + 1) // 2
+
+ return x, lengths
+
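+    # Illustrative example for _get_attn_mask() below (hypothetical numbers,
+    # ignoring the minimum-left-context assertion): with chunk_size=4,
+    # left_context_chunks=1 and seq_len=12, frames get chunk indices
+    # [0,0,0,0, 1,1,1,1, 2,2,2,2]; a target frame in chunk 2 may attend to
+    # chunks 1 and 2 only, so source frames 0..3 are masked (True) for it.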
+ def _get_attn_mask(
+ self, x: Tensor, chunk_size: int, left_context_chunks: int
+ ) -> Optional[Tensor]:
+ """
+ Return None if chunk_size == -1, else return attention mask of shape
+ (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True
+ means a masked position.
+ Args:
+          x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
+          chunk_size: chunk size; must be divisible by all elements of self.downsampling_factor.
+ """
+ if chunk_size <= 0:
+ return None
+ assert all(chunk_size % d == 0 for d in self.downsampling_factor)
+ if left_context_chunks >= 0:
+ num_encoders = len(self.encoder_dim)
+ assert all(
+ chunk_size * left_context_chunks
+ >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
+ for i in range(num_encoders)
+ )
+ else:
+ left_context_chunks = 1000000
+
+ seq_len = x.shape[0]
+
+ # t is frame index, shape (seq_len,)
+ t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
+ # c is chunk index for each frame, shape (seq_len,)
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ c = t // chunk_size
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ c = t // chunk_size
+ src_c = c
+ tgt_c = c.unsqueeze(-1)
+
+ attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
+ if __name__ == "__main__":
+ logging.info(f"attn_mask = {attn_mask}")
+ return attn_mask
+
+ def _get_full_dim_output(self, outputs: List[Tensor]):
+ num_encoders = len(self.encoder_dim)
+ assert len(outputs) == num_encoders
+ output_dim = max(self.encoder_dim)
+ output_pieces = [outputs[-1]]
+ cur_dim = self.encoder_dim[-1]
+ for i in range(num_encoders - 2, -1, -1):
+ d = self.encoder_dim[i]
+ if d > cur_dim:
+ this_output = outputs[i]
+ output_pieces.append(this_output[..., cur_dim:d])
+ cur_dim = d
+ assert cur_dim == output_dim
+ return torch.cat(output_pieces, dim=-1)
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ x_lens: Tensor,
+ states: List[Tensor],
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (seq_len, batch_size, feature_dim).
+ x_lens:
+ A tensor of shape (batch_size,) containing the number of frames in
+ `x` before padding.
+ states: list of cached tensors of all encoder layers. For layer-i,
+ states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+ cached_conv1, cached_conv2).
+ src_key_padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ Returns:
+          Return a tuple containing 3 items:
+ - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
+ - lengths, a tensor of shape (batch_size,) containing the number
+ of frames in `embeddings` before padding.
+ - updated states
+ """
+ outputs = []
+ new_states = []
+ layer_offset = 0
+
+ for i, module in enumerate(self.encoders):
+ num_layers = module.num_layers
+ ds = self.downsampling_factor[i]
+ x = convert_num_channels(x, self.encoder_dim[i])
+
+ x, new_layer_states = module.streaming_forward(
+ x,
+ states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
+ left_context_len=self.left_context_frames[0] // ds,
+ src_key_padding_mask=src_key_padding_mask[..., ::ds],
+ )
+ layer_offset += num_layers
+ outputs.append(x)
+ new_states += new_layer_states
+
+ # if the last output has the largest dimension, x will be unchanged,
+ # it will be the same as outputs[-1]. Otherwise it will be concatenated
+ # from different pieces of 'outputs', taking each dimension from the
+ # most recent output that has it present.
+ x = self._get_full_dim_output(outputs)
+ x = self.downsample_output(x)
+        # SimpleDownsample (self.downsample_output) has this rounding behavior.
+ assert self.output_downsampling_factor == 2
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ lengths = (x_lens + 1) // 2
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ lengths = (x_lens + 1) // 2
+
+ return x, lengths, new_states
+
+ @torch.jit.export
+ def get_init_states(
+ self,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+ ) -> List[Tensor]:
+ """Get initial states.
+
+ A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
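+
+        For example (illustrative), a model whose stacks contain 16 layers in
+        total returns a list of 16 * 6 = 96 cached tensors.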
+ """
+ states = []
+ for i, module in enumerate(self.encoders):
+ num_layers = module.num_layers
+ embed_dim = self.encoder_dim[i]
+ ds = self.downsampling_factor[i]
+ num_heads = self.num_heads[i]
+ key_dim = self.query_head_dim[i] * num_heads
+ value_dim = self.value_head_dim[i] * num_heads
+ downsample_left = self.left_context_frames[0] // ds
+ nonlin_attn_head_dim = 3 * embed_dim // 4
+ conv_left_pad = self.cnn_module_kernel[i] // 2
+ for layer in range(num_layers):
+ cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
+ device
+ )
+ cached_nonlin_attn = torch.zeros(
+ 1, batch_size, downsample_left, nonlin_attn_head_dim
+ ).to(device)
+ cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
+ device
+ )
+ cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
+ device
+ )
+ cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
+ device
+ )
+ cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
+ device
+ )
+ states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+
+ return states
+
+
+def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
+ return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
+
+
+def _balancer_schedule(min_prob: float):
+ return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))
+
+
+class Zipformer2EncoderLayer(nn.Module):
+ """
+ Args:
+ embed_dim: the number of expected features in the input (required).
+        num_heads: the number of heads in the multi-head attention mechanism (required).
+ feedforward_dim: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ cnn_module_kernel (int): Kernel size of convolution module.
+ use_adapters: insert adapters in each layer
+ adapter_dim: the bottleneck dimension of the adapter
+
+ Examples::
+        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, pos_dim=192,
+        ...     num_heads=8, query_head_dim=32, pos_head_dim=4,
+        ...     value_head_dim=12, feedforward_dim=2048)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(1, 19, 192)
+        >>> out = encoder_layer(src, pos_emb)
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ value_head_dim: int,
+ feedforward_dim: int,
+ dropout: FloatLike = 0.1,
+ cnn_module_kernel: int = 31,
+ causal: bool = False,
+ attention_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ conv_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ const_attention_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.25), (4000.0, 0.025), default=0
+ ),
+ ff2_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ ff3_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ bypass_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.5), (4000.0, 0.02), default=0
+ ),
+ use_adapters: bool = False,
+ adapter_dim: int = 16,
+ ) -> None:
+ super(Zipformer2EncoderLayer, self).__init__()
+ self.embed_dim = embed_dim
+
+ # self.bypass implements layer skipping as well as bypass; see its default values.
+ self.bypass = BypassModule(
+ embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
+ )
+ # bypass_mid is bypass used in the middle of the layer.
+ self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
+
+ # skip probability for dynamic modules (meaning: anything but feedforward).
+ self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
+ # an additional skip probability that applies to ConvModule to stop it from
+ # contributing too much early on.
+ self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+
+ # ff2_skip_rate is to prevent the ff2 module from having output that's too big
+ # compared to its residual.
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
+ self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
+
+ self.const_attention_rate = copy.deepcopy(const_attention_rate)
+
+ self.self_attn_weights = RelPositionMultiheadAttentionWeights(
+ embed_dim,
+ pos_dim=pos_dim,
+ num_heads=num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ dropout=0.0,
+ )
+
+ self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
+
+ self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
+
+ self.feed_forward1 = FeedforwardModule(
+ embed_dim, (feedforward_dim * 3) // 4, dropout
+ )
+
+ self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
+
+ self.feed_forward3 = FeedforwardModule(
+ embed_dim, (feedforward_dim * 5) // 4, dropout
+ )
+
+ self.nonlin_attention = NonlinAttention(
+ embed_dim, hidden_channels=3 * embed_dim // 4
+ )
+
+ self.conv_module1 = ConvolutionModule(
+ embed_dim, cnn_module_kernel, causal=causal
+ )
+
+ self.conv_module2 = ConvolutionModule(
+ embed_dim, cnn_module_kernel, causal=causal
+ )
+
+ # TODO: remove it
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+
+ self.norm = BiasNorm(embed_dim)
+
+ self.balancer1 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.2,
+ max_abs=4.0,
+ )
+
+ # balancer for output of NonlinAttentionModule
+ self.balancer_na = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+ prob=0.05, # out of concern for memory usage
+ )
+
+        # Balancer for the output of feedforward2, to prevent it from staying too
+        # small.  We give this a very small probability, even at the start of
+        # training; it is there to fix a rare problem and it's OK to fix it slowly.
+ self.balancer_ff2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
+ max_abs=2.0,
+ prob=0.05,
+ )
+
+ self.balancer_ff3 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
+ max_abs=4.0,
+ prob=0.05,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(4.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.balancer2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.1,
+ max_abs=4.0,
+ )
+
+ self.use_adapters = use_adapters
+ if use_adapters:
+ self.mid_adapter = AdapterModule(
+ embed_dim=embed_dim,
+ bottleneck_dim=adapter_dim,
+ )
+
+ # placed after the 1st self-attn module
+ self.post_sa_adapter = AdapterModule(
+ embed_dim=embed_dim,
+ bottleneck_dim=adapter_dim,
+ )
+
+ # placed after the 2nd convolution module
+ self.post_conv_adapter = AdapterModule(
+ embed_dim=embed_dim,
+ bottleneck_dim=adapter_dim,
+ )
+
+ # at the end of each layer
+ self.adapter = AdapterModule(
+ embed_dim=embed_dim,
+ bottleneck_dim=adapter_dim,
+ )
+ else:
+ self.mid_adapter = None
+ self.post_sa_adapter = None
+ self.post_conv_adapter = None
+ self.adapter = None
+
+ def get_sequence_dropout_mask(
+ self, x: Tensor, dropout_rate: float
+ ) -> Optional[Tensor]:
+ if (
+ dropout_rate == 0.0
+ or not self.training
+ or torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ ):
+ return None
+ batch_size = x.shape[1]
+ mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
+ return mask
+
+ def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
+ """
+ Apply sequence-level dropout to x.
+ x shape: (seq_len, batch_size, embed_dim)
+ """
+ dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
+ if dropout_mask is None:
+ return x
+ else:
+ return x * dropout_mask
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ chunk_size: int = -1,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Pass the input through the encoder layer.
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+            chunk_size: the number of frames per chunk, if >= 1; if -1, no chunking.
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns:
+ A tensor which has the same shape as src
+ """
+ src_orig = src
+
+ # dropout rate for non-feedforward submodules
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ attention_skip_rate = 0.0
+ else:
+ attention_skip_rate = (
+ float(self.attention_skip_rate) if self.training else 0.0
+ )
+
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ attn_weights = self.self_attn_weights(
+ src,
+ pos_emb=pos_emb,
+ attn_mask=attn_mask,
+ key_padding_mask=src_key_padding_mask,
+ )
+
+ src = src + self.feed_forward1(src)
+
+ self_attn_dropout_mask = self.get_sequence_dropout_mask(
+ src, attention_skip_rate
+ )
+
+ selected_attn_weights = attn_weights[0:1]
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif not self.training and random.random() < float(self.const_attention_rate):
+ # Make attention weights constant. The intention is to
+ # encourage these modules to do something similar to an
+ # averaging-over-time operation.
+ # only need the mask, can just use the 1st one and expand later
+ selected_attn_weights = selected_attn_weights[0:1]
+ selected_attn_weights = (selected_attn_weights > 0.0).to(
+ selected_attn_weights.dtype
+ )
+ selected_attn_weights = selected_attn_weights * (
+ 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
+ )
+
+ na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+
+ src = src + (
+ na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
+ )
+
+ self_attn = self.self_attn1(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if self.use_adapters and self.post_sa_adapter is not None:
+ src = self.post_sa_adapter(src)
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.conv_module1(
+ src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff2_skip_rate = 0.0
+ else:
+ ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
+ )
+
+ # bypass in the middle of the layer.
+ src = self.bypass_mid(src_orig, src)
+
+ if self.use_adapters and self.mid_adapter is not None:
+ src = self.mid_adapter(src)
+
+ self_attn = self.self_attn2(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.conv_module2(
+ src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
+ ),
+ conv_skip_rate,
+ )
+
+ if self.use_adapters and self.post_conv_adapter is not None:
+ src = self.post_conv_adapter(src)
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff3_skip_rate = 0.0
+ else:
+ ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
+ )
+
+ src = self.balancer1(src)
+ src = self.norm(src)
+
+ src = self.bypass(src_orig, src)
+
+ src = self.balancer2(src)
+ src = self.whiten(src)
+
+ if self.use_adapters and self.adapter is not None:
+ src = self.adapter(src)
+
+ return src
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ cached_key: Tensor,
+ cached_nonlin_attn: Tensor,
+ cached_val1: Tensor,
+ cached_val2: Tensor,
+ cached_conv1: Tensor,
+ cached_conv2: Tensor,
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """Pass the input through the encoder layer in streaming forward mode.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
+ (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
+ cached_key: cached attention key tensor of left context,
+ of shape (left_context_len, batch_size, key_dim)
+ cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
+ (num_heads, batch_size, left_context_len, head_dim)
+ cached_val1: cached left context for the first attention module,
+ of shape (left_context_len, batch_size, value_dim)
+ cached_val2: cached left context for the second attention module,
+ of shape (left_context_len, batch_size, value_dim)
+ cached_conv1: cached left context for the first convolution module,
+ of shape (batch_size, channels, left_pad)
+ cached_conv2: cached left context for the second convolution module,
+ of shape (batch_size, channels, left_pad)
+ left_context_len: number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape
+ (batch_size, left_context_len + seq_len); True means masked position.
+ May be None.
+
+ Returns:
+ - x, with the same shape as src
+ - updated cached_key
+ - updated cached_nonlin_attn
+ - updated cached_val1
+ - updated cached_val2
+ - updated cached_conv1
+ - updated cached_conv2
+ """
+ src_orig = src
+
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ attn_weights, cached_key = self.self_attn_weights.streaming_forward(
+ src,
+ pos_emb=pos_emb,
+ cached_key=cached_key,
+ left_context_len=left_context_len,
+ key_padding_mask=src_key_padding_mask,
+ )
+
+ src = src + self.feed_forward1(src)
+
+ na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
+ src,
+ attn_weights[0:1],
+ cached_x=cached_nonlin_attn,
+ left_context_len=left_context_len,
+ )
+ src = src + na
+
+ self_attn, cached_val1 = self.self_attn1.streaming_forward(
+ src,
+ attn_weights=attn_weights,
+ cached_val=cached_val1,
+ left_context_len=left_context_len,
+ )
+ src = src + self_attn
+
+ if self.use_adapters and self.post_sa_adapter is not None:
+ src = self.post_sa_adapter(src)
+
+ src_conv, cached_conv1 = self.conv_module1.streaming_forward(
+ src,
+ cache=cached_conv1,
+ src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
+ )
+ src = src + src_conv
+
+ src = src + self.feed_forward2(src)
+
+ # bypass in the middle of the layer.
+ src = self.bypass_mid(src_orig, src)
+
+ if self.use_adapters and self.mid_adapter is not None:
+ src = self.mid_adapter(src)
+
+ self_attn, cached_val2 = self.self_attn2.streaming_forward(
+ src,
+ attn_weights=attn_weights,
+ cached_val=cached_val2,
+ left_context_len=left_context_len,
+ )
+ src = src + self_attn
+
+ src_conv, cached_conv2 = self.conv_module2.streaming_forward(
+ src,
+ cache=cached_conv2,
+ src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
+ )
+ src = src + src_conv
+
+ if self.use_adapters and self.post_conv_adapter is not None:
+ src = self.post_conv_adapter(src)
+
+ src = src + self.feed_forward3(src)
+
+ src = self.norm(src)
+
+ src = self.bypass(src_orig, src)
+
+ if self.use_adapters and self.adapter is not None:
+ src = self.adapter(src)
+
+ return (
+ src,
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ )
+
+
+class Zipformer2Encoder(nn.Module):
+ r"""Zipformer2Encoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+ pos_dim: the dimension for the relative positional encoding
+
+ Examples::
+        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, pos_dim=192,
+        ...     num_heads=8, query_head_dim=32, pos_head_dim=4,
+        ...     value_head_dim=12, feedforward_dim=2048)
+        >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6,
+        ...     pos_dim=192, dropout=0.1, warmup_begin=0.0, warmup_end=4000.0)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = zipformer_encoder(src)
+ """
+
+ def __init__(
+ self,
+ encoder_layer: nn.Module,
+ num_layers: int,
+ pos_dim: int,
+ dropout: float,
+ warmup_begin: float,
+ warmup_end: float,
+ initial_layerdrop_rate: float = 0.5,
+ final_layerdrop_rate: float = 0.05,
+ ) -> None:
+ super().__init__()
+ self.encoder_pos = CompactRelPositionalEncoding(
+ pos_dim, dropout_rate=0.15, length_factor=1.0
+ )
+
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ assert 0 <= warmup_begin <= warmup_end
+
+ delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
+ cur_begin = warmup_begin # interpreted as a training batch index
+ for i in range(num_layers):
+ cur_end = cur_begin + delta
+ self.layers[i].bypass.skip_rate = ScheduledFloat(
+ (cur_begin, initial_layerdrop_rate),
+ (cur_end, final_layerdrop_rate),
+ default=0.0,
+ )
+ cur_begin = cur_end
+
+ def forward(
+ self,
+ src: Tensor,
+ chunk_size: int = -1,
+ feature_mask: Union[Tensor, float] = 1.0,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            chunk_size: the number of frames per chunk, if >= 1; if -1, no chunking.
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
+ by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ pos_emb = self.encoder_pos(src)
+ output = src
+
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ output = output * feature_mask
+
+ for i, mod in enumerate(self.layers):
+ output = mod(
+ output,
+ pos_emb,
+ chunk_size=chunk_size,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ output = output * feature_mask
+
+ return output
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ states: List[Tensor],
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, List[Tensor]]:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
+ (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ left_context_len: Number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape
+ (batch_size, left_context_len + seq_len); True means masked position.
+ May be None.
+
+ Returns:
+ - output, a Tensor with the same shape as src.
+ - updated states
+ """
+ pos_emb = self.encoder_pos(src, left_context_len)
+ output = src
+
+ new_states = []
+ for i, mod in enumerate(self.layers):
+ (
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ) = states[i * 6 : (i + 1) * 6]
+ (
+ output,
+ new_cached_key,
+ new_cached_nonlin_attn,
+ new_cached_val1,
+ new_cached_val2,
+ new_cached_conv1,
+ new_cached_conv2,
+ ) = mod.streaming_forward(
+ output,
+ pos_emb,
+ cached_key=cached_key,
+ cached_nonlin_attn=cached_nonlin_attn,
+ cached_val1=cached_val1,
+ cached_val2=cached_val2,
+ cached_conv1=cached_conv1,
+ cached_conv2=cached_conv2,
+ left_context_len=left_context_len,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ new_states += [
+ new_cached_key,
+ new_cached_nonlin_attn,
+ new_cached_val1,
+ new_cached_val2,
+ new_cached_conv1,
+ new_cached_conv2,
+ ]
+
+ return output, new_states
+
+
+class BypassModule(nn.Module):
+ """
+ An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
+ layer-skipping. The bypass is limited during early stages of training to be close to
+ "straight-through", i.e. to not do the bypass operation much initially, in order to
+ force all the modules to learn something.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ skip_rate: FloatLike = 0.0,
+ straight_through_rate: FloatLike = 0.0,
+ scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
+ scale_max: FloatLike = 1.0,
+ ):
+ super().__init__()
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+ self.skip_rate = copy.deepcopy(skip_rate)
+ self.straight_through_rate = copy.deepcopy(straight_through_rate)
+ self.scale_min = copy.deepcopy(scale_min)
+ self.scale_max = copy.deepcopy(scale_max)
+
+ def _get_bypass_scale(self, batch_size: int):
+ # returns bypass-scale of shape (num_channels,),
+ # or (batch_size, num_channels,). This is actually the
+        # scale on the non-residual term, so 0 corresponds to bypassing
+ # this module.
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+ return self.bypass_scale
+ else:
+ ans = limit_param_value(
+ self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
+ )
+ skip_rate = float(self.skip_rate)
+ if skip_rate != 0.0:
+ mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
+ ans = ans * mask
+ # now ans is of shape (batch_size, num_channels), and is zero for sequences
+ # on which we have randomly chosen to do layer-skipping.
+ straight_through_rate = float(self.straight_through_rate)
+ if straight_through_rate != 0.0:
+ mask = (
+ torch.rand((batch_size, 1), device=ans.device)
+ < straight_through_rate
+ )
+ ans = torch.maximum(ans, mask.to(ans.dtype))
+ return ans
+
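+    # Illustrative note: forward() below interpolates between the residual input
+    # and the module output, i.e. out = src_orig + (src - src_orig) * bypass_scale,
+    # so a scale of 0 returns src_orig unchanged (full bypass) and a scale of 1
+    # returns src (no bypass).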
+ def forward(self, src_orig: Tensor, src: Tensor):
+ """
+ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
+ Returns: something with the same shape as src and src_orig
+ """
+ bypass_scale = self._get_bypass_scale(src.shape[1])
+ return src_orig + (src - src_orig) * bypass_scale
+
+
+class DownsampledZipformer2Encoder(nn.Module):
+ r"""
+    DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate
+    (after downsampling), then upsampled again at the output and combined
+    with the original input, so that the output has the same shape as the input.
+ """
+
+ def __init__(
+ self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
+ ):
+ super(DownsampledZipformer2Encoder, self).__init__()
+ self.downsample_factor = downsample
+ self.downsample = SimpleDownsample(dim, downsample, dropout)
+ self.num_layers = encoder.num_layers
+ self.encoder = encoder
+ self.upsample = SimpleUpsample(dim, downsample)
+ self.out_combiner = BypassModule(dim, straight_through_rate=0)
+
+ def forward(
+ self,
+ src: Tensor,
+ chunk_size: int = -1,
+ feature_mask: Union[Tensor, float] = 1.0,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Downsample, go through encoder, upsample.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
+ by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ src_orig = src
+ src = self.downsample(src)
+ ds = self.downsample_factor
+ if attn_mask is not None:
+ attn_mask = attn_mask[::ds, ::ds]
+
+ src = self.encoder(
+ src,
+ chunk_size=chunk_size // ds,
+ feature_mask=feature_mask,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ src = self.upsample(src)
+ # remove any extra frames that are not a multiple of downsample_factor
+ src = src[: src_orig.shape[0]]
+
+ return self.out_combiner(src_orig, src)
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ states: List[Tensor],
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, List[Tensor]]:
+ r"""Downsample, go through encoder, upsample, in streaming forward mode.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
+ (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ left_context_len: Number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len);
+ True means masked position. May be None.
+
+ Returns:
+ - output, a Tensor with the same shape as src.
+ - updated states
+ """
+ src_orig = src
+ src = self.downsample(src)
+
+ src, new_states = self.encoder.streaming_forward(
+ src,
+ states=states,
+ left_context_len=left_context_len,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ src = self.upsample(src)
+ # remove any extra frames that are not a multiple of downsample_factor
+ src = src[: src_orig.shape[0]]
+
+ return self.out_combiner(src_orig, src), new_states
+
+
+class SimpleDownsample(torch.nn.Module):
+ """
+    Does downsampling by a weighted sum over groups of frames, with learned softmax-normalized weights.
+ """
+
+ def __init__(self, channels: int, downsample: int, dropout: FloatLike):
+ super(SimpleDownsample, self).__init__()
+
+ self.bias = nn.Parameter(torch.zeros(downsample))
+
+ self.name = None # will be set from training code
+ self.dropout = copy.deepcopy(dropout)
+
+ self.downsample = downsample
+
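+    # Illustrative example (hypothetical sizes): with downsample=2 and seq_len=5,
+    # forward() right-pads the input to length 6 by repeating the last frame,
+    # reshapes it to (3, 2, batch_size, channels), and combines each group of
+    # 2 frames with a learned softmax weighting over self.bias.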
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ x: (seq_len, batch_size, in_channels)
+ Returns a tensor of shape
+ ( (seq_len+downsample-1)//downsample, batch_size, channels)
+ """
+ (seq_len, batch_size, in_channels) = src.shape
+ ds = self.downsample
+ d_seq_len = (seq_len + ds - 1) // ds
+
+ # Pad to an exact multiple of self.downsample
+ # right-pad src, repeating the last element.
+ pad = d_seq_len * ds - seq_len
+ src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+ src = torch.cat((src, src_extra), dim=0)
+ assert src.shape[0] == d_seq_len * ds
+
+ src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+
+ weights = self.bias.softmax(dim=0)
+ # weights: (downsample, 1, 1)
+ weights = weights.unsqueeze(-1).unsqueeze(-1)
+
+        # ans: ( (seq_len+downsample-1)//downsample, batch_size, in_channels)
+ ans = (src * weights).sum(dim=1)
+
+ return ans
+
+
+class SimpleUpsample(torch.nn.Module):
+ """
+    A very simple form of upsampling that just repeats each input frame `upsample` times.
+ """
+
+ def __init__(self, num_channels: int, upsample: int):
+ super(SimpleUpsample, self).__init__()
+ self.upsample = upsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ x: (seq_len, batch_size, num_channels)
+ Returns a tensor of shape
+ ( (seq_len*upsample), batch_size, num_channels)
+ """
+ upsample = self.upsample
+ (seq_len, batch_size, num_channels) = src.shape
+ src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+ src = src.reshape(seq_len * upsample, batch_size, num_channels)
+ return src
+
+
+class CompactRelPositionalEncoding(torch.nn.Module):
+ """
+ Relative positional encoding module. This version is "compact" meaning it is able to encode
+ the important information about the relative position in a relatively small number of dimensions.
+ The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
+ make very little difference to the embedding. Such differences were potentially important
+ when encoding absolute position, but not important when encoding relative position because there
+ is now no need to compare two large offsets with each other.
+
+    Our embedding works by projecting the interval [-infinity,infinity] onto a finite interval
+    using the atan() function, before doing the Fourier transform of that fixed interval.  The
+    atan() function on its own would compress the "long tails" too much,
+    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
+    function to compress large offsets to a smaller range before applying atan().
+    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
+    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim).
+
+
+ Args:
+ embed_dim: Embedding dimension.
+ dropout_rate: Dropout rate.
+ max_len: Maximum input length: just a heuristic for initialization.
+ length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+ less weight to small differences of offset near the origin.
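+
+    Example (illustrative only; the sizes are arbitrary)::
+
+        >>> pos_enc = CompactRelPositionalEncoding(embed_dim=192, dropout_rate=0.15)
+        >>> x = torch.rand(10, 4, 384)   # only the time dimension (10) is used here
+        >>> pos_emb = pos_enc(x)         # shape (1, 2*10-1, 192) == (1, 19, 192)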
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ dropout_rate: FloatLike,
+ max_len: int = 1000,
+ length_factor: float = 1.0,
+ ) -> None:
+ """Construct a CompactRelPositionalEncoding object."""
+ super(CompactRelPositionalEncoding, self).__init__()
+ self.embed_dim = embed_dim
+ assert embed_dim % 2 == 0
+ self.dropout = Dropout2(dropout_rate)
+ self.pe = None
+ assert length_factor >= 1.0
+ self.length_factor = length_factor
+ self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+ """Reset the positional encodings."""
+ T = x.size(0) + left_context_len
+
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(0) >= T * 2 - 1:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+
+        # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
+ x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
+
+ freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+
+        # `compression_length`: this is arbitrary/heuristic; if it is larger we have more resolution
+ # for small time offsets but less resolution for large time offsets.
+ compression_length = self.embed_dim**0.5
+        # x_compressed, like x, goes from -infinity to infinity as x goes from -infinity to infinity;
+        # but it does so more slowly than x for large absolute values of x.
+ # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which
+ # is important.
+ x_compressed = (
+ compression_length
+ * x.sign()
+ * ((x.abs() + compression_length).log() - math.log(compression_length))
+ )
+
+ # if self.length_factor == 1.0, then length_scale is chosen so that the
+ # FFT can exactly separate points close to the origin (T == 0). So this
+ # part of the formulation is not really heuristic.
+ # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+ length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
+
+ # note for machine implementations: if atan is not available, we can use:
+ # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
+ # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
+ x_atan = (x_compressed / length_scale).atan() # results between -pi and pi
+
+ cosines = (x_atan * freqs).cos()
+ sines = (x_atan * freqs).sin()
+
+ pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+ pe[:, 0::2] = cosines
+ pe[:, 1::2] = sines
+ pe[:, -1] = 1.0 # for bias.
+
+ self.pe = pe.to(dtype=x.dtype)
+
+ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+ """Create positional encoding.
+
+ Args:
+ x (Tensor): Input tensor (time, batch, `*`).
+ left_context_len: (int): Length of cached left context.
+
+ Returns:
+ positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
+ """
+ self.extend_pe(x, left_context_len)
+ x_size_left = x.size(0) + left_context_len
+ # length of positive side: x.size(0) + left_context_len
+ # length of negative side: x.size(0)
+ pos_emb = self.pe[
+ self.pe.size(0) // 2
+ - x_size_left
+ + 1 : self.pe.size(0) // 2 # noqa E203
+ + x.size(0),
+ :,
+ ]
+ pos_emb = pos_emb.unsqueeze(0)
+ return self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttentionWeights(nn.Module):
+ r"""Module that computes multi-head attention weights with relative position encoding.
+ Various other modules consume the resulting attention weights: see, for example, the
+ SimpleAttention module which allows you to compute conventional attention.
+
+    This is quite heavily modified from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context";
+    we still have to write up the differences.
+
+
+ Args:
+ embed_dim: number of channels at the input to this module, e.g. 256
+ pos_dim: dimension of the positional encoding vectors, e.g. 128.
+ num_heads: number of heads to compute weights for, e.g. 8
+ query_head_dim: dimension of the query (and key), per head. e.g. 24.
+ pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
+ dropout: dropout probability for attn_output_weights. Default: 0.0.
+ pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
+ any given call to forward(), in training time.
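+
+    Example (illustrative only; shapes are arbitrary)::
+
+        >>> attn_weights_module = RelPositionMultiheadAttentionWeights(
+        ...     embed_dim=256, pos_dim=192, num_heads=8,
+        ...     query_head_dim=24, pos_head_dim=4)
+        >>> x = torch.rand(10, 4, 256)           # (seq_len, batch_size, embed_dim)
+        >>> pos_emb = torch.rand(1, 19, 192)     # (1, 2*seq_len-1, pos_dim)
+        >>> w = attn_weights_module(x, pos_emb)  # (num_heads, batch_size, seq_len, seq_len)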
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ dropout: float = 0.0,
+ pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
+ ) -> None:
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.query_head_dim = query_head_dim
+ self.pos_head_dim = pos_head_dim
+ self.dropout = dropout
+ self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
+ self.name = None # will be overwritten in training code; for diagnostics.
+
+ key_head_dim = query_head_dim
+ in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
+
+ # the initial_scale is supposed to take over the "scaling" factor of
+ # head_dim ** -0.5 that has been used in previous forms of attention,
+ # dividing it between the query and key. Note: this module is intended
+ # to be used with the ScaledAdam optimizer; with most other optimizers,
+ # it would be necessary to apply the scaling factor in the forward function.
+ self.in_proj = ScaledLinear(
+ embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
+ )
+
+ self.whiten_keys = Whiten(
+ num_groups=num_heads,
+ whitening_limit=_whitening_schedule(3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.025,
+ )
+
+ # add a balancer for the keys that runs with very small probability, and
+ # tries to enforce that all dimensions have mean around zero. The
+ # weights produced by this module are invariant to adding a constant to
+ # the keys, so the derivative of the bias is mathematically zero; but
+ # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
+ # bias because the small numerical roundoff tends to have a non-random
+ # sign. This module is intended to prevent that. Use a very small
+        # probability; that should be sufficient to fix the problem.
+ self.balance_keys = Balancer(
+ key_head_dim * num_heads,
+ channel_dim=-1,
+ min_positive=0.4,
+ max_positive=0.6,
+ min_abs=0.0,
+ max_abs=100.0,
+ prob=0.025,
+ )
+
+ # linear transformation for positional encoding.
+ self.linear_pos = ScaledLinear(
+ pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
+ )
+
+        # the following are for diagnostics only, see --print-diagnostics option
+ self.copy_pos_query = Identity()
+ self.copy_query = Identity()
+
+ def forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""
+ Args:
+ x: input of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
+ are True in this mask will be ignored as sources in the attention weighting.
+ attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
+ interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
+ saying which positions are allowed to attend to which other positions.
+ Returns:
+           a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
+           interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+ """
+ x = self.in_proj(x)
+ query_head_dim = self.query_head_dim
+ pos_head_dim = self.pos_head_dim
+ num_heads = self.num_heads
+
+ seq_len, batch_size, _ = x.shape
+
+ query_dim = query_head_dim * num_heads
+
+ # self-attention
+ q = x[..., 0:query_dim]
+ k = x[..., query_dim : 2 * query_dim]
+ # p is the position-encoding query
+ p = x[..., 2 * query_dim :]
+ assert p.shape[-1] == num_heads * pos_head_dim
+
+ q = self.copy_query(q) # for diagnostics only, does nothing.
+ k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
+ p = self.copy_pos_query(p) # for diagnostics only, does nothing.
+
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+ k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
+
+ # time1 refers to target, time2 refers to source.
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+
+ attn_scores = torch.matmul(q, k)
+
+ use_pos_scores = False
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+            # random.random() cannot be used when scripting or tracing, so the
+            # positional scores are always included in that case.
+ use_pos_scores = True
+ elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+ use_pos_scores = True
+
+ if use_pos_scores:
+ pos_emb = self.linear_pos(pos_emb)
+ seq_len2 = 2 * seq_len - 1
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+ 2, 0, 3, 1
+ )
+            # pos_emb shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2)
+
+            # (head, batch, time1, pos_head_dim) x (head, 1, pos_head_dim, seq_len2) -> (head, batch, time1, seq_len2)
+ # [where seq_len2 represents relative position.]
+ pos_scores = torch.matmul(p, pos_emb)
+ # the following .as_strided() expression converts the last axis of pos_scores from relative
+ # to absolute position. I don't know whether I might have got the time-offsets backwards or
+ # not, but let this code define which way round it is supposed to be.
+ if torch.jit.is_tracing():
+ (num_heads, batch_size, time1, n) = pos_scores.shape
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+ cols = torch.arange(seq_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+ pos_scores = pos_scores.reshape(-1, n)
+ pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+ pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
+ else:
+ pos_scores = pos_scores.as_strided(
+ (num_heads, batch_size, seq_len, seq_len),
+ (
+ pos_scores.stride(0),
+ pos_scores.stride(1),
+ pos_scores.stride(2) - pos_scores.stride(3),
+ pos_scores.stride(3),
+ ),
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
+ )
+
+ attn_scores = attn_scores + pos_scores
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif self.training and random.random() < 0.1:
+            # This is a harder way of limiting the attention scores to not be
+            # too large. It incurs a penalty if any of them has an absolute
+            # value greater than 25.0, which should be outside the normal range
+            # of the attention scores. We use this mechanism instead of, say,
+            # something added to the loss function involving the entropy,
+            # because once the entropy gets very small, gradients through the
+            # softmax can become very small, and we'd get zero derivatives. The
+            # choice of 1.0e-04 as the scale on the penalty makes this
+ # mechanism vulnerable to the absolute scale of the loss function,
+ # but we view this as a failsafe to avoid "implausible" parameter
+ # values rather than a regularization method that should be active
+ # under normal circumstances.
+ attn_scores = penalize_abs_values_gt(
+ attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
+ )
+
+ assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ if attn_mask is not None:
+ assert attn_mask.dtype == torch.bool
+            # use -1000 to avoid nan's where attn_mask and key_padding_mask together
+            # mask out all positions. It's important that this be negative enough that
+            # exp(-1000) is exactly zero: const_attention_rate relies on this, since it
+            # compares the final weights with zero.
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (
+ batch_size,
+ seq_len,
+ ), key_padding_mask.shape
+ attn_scores = attn_scores.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ -1000,
+ )
+
+        # We use our own version of softmax, defined in scaling.py, which should
+        # save a little of the memory used in backprop: if we are in
+        # automatic mixed precision mode (amp / autocast), it only stores the
+        # half-precision output for backprop purposes.
+ attn_weights = softmax(attn_scores, dim=-1)
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif random.random() < 0.001 and not self.training:
+ self._print_attn_entropy(attn_weights)
+
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+
+ return attn_weights
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ cached_key: Tensor,
+ left_context_len: int,
+ key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ r"""
+ Args:
+ x: input of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
+ cached_key: cached attention key tensor of left context,
+ of shape (left_context_len, batch_size, key_dim)
+ left_context_len: number of left context frames.
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
+ are True in this mask will be ignored as sources in the attention weighting.
+
+ Returns:
+          - attention weights, of shape (num_heads, batch_size, seq_len, seq_len2),
+            interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+ - updated cached attention key tensor of left context.
+ """
+ x = self.in_proj(x)
+ query_head_dim = self.query_head_dim
+ pos_head_dim = self.pos_head_dim
+ num_heads = self.num_heads
+
+ seq_len, batch_size, _ = x.shape
+
+ query_dim = query_head_dim * num_heads
+
+ # self-attention
+ q = x[..., 0:query_dim]
+ k = x[..., query_dim : 2 * query_dim]
+ # p is the position-encoding query
+ p = x[..., 2 * query_dim :]
+ assert p.shape[-1] == num_heads * pos_head_dim
+
+ # Pad cached left contexts
+ assert cached_key.shape[0] == left_context_len, (
+ cached_key.shape[0],
+ left_context_len,
+ )
+ k = torch.cat([cached_key, k], dim=0)
+ # Update cached left contexts
+ cached_key = k[-left_context_len:, ...]
+
+ # The length of key
+ k_len = k.shape[0]
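+        # (i.e. k_len == left_context_len + seq_len)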
+
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+ k = k.reshape(k_len, batch_size, num_heads, query_head_dim)
+
+ # time1 refers to target, time2 refers to source.
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+
+ attn_scores = torch.matmul(q, k)
+
+ pos_emb = self.linear_pos(pos_emb)
+ seq_len2 = 2 * seq_len - 1 + left_context_len
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+ 2, 0, 3, 1
+ )
+ # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+
+ # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+ # [where seq_len2 represents relative position.]
+ pos_scores = torch.matmul(p, pos_emb)
+
+        # the following .as_strided() expression converts the last axis of pos_scores from relative
+        # to absolute position. I don't know whether I might have got the time-offsets backwards or
+        # not, but let this code define which way round it is supposed to be.
+        if torch.jit.is_tracing():
+            (num_heads, batch_size, time1, n) = pos_scores.shape
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(k_len)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
+            pos_scores = pos_scores.reshape(-1, n)
+            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
+ else:
+ pos_scores = pos_scores.as_strided(
+ (num_heads, batch_size, seq_len, k_len),
+ (
+ pos_scores.stride(0),
+ pos_scores.stride(1),
+ pos_scores.stride(2) - pos_scores.stride(3),
+ pos_scores.stride(3),
+ ),
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
+ )
+
+ attn_scores = attn_scores + pos_scores
+
+ assert attn_scores.shape == (
+ num_heads,
+ batch_size,
+ seq_len,
+ k_len,
+ ), attn_scores.shape
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
+ attn_scores = attn_scores.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ -1000,
+ )
+
+ attn_weights = attn_scores.softmax(dim=-1)
+
+ return attn_weights, cached_key
+
+ def _print_attn_entropy(self, attn_weights: Tensor):
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ attn_weights = attn_weights.to(torch.float32)
+ attn_weights_entropy = (
+ -((attn_weights + 1.0e-20).log() * attn_weights)
+ .sum(dim=-1)
+ .mean(dim=(1, 2))
+ )
+ logging.info(
+ f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
+ )
+
+
+class SelfAttention(nn.Module):
+ """
+ The simplest possible attention module. This one works with already-computed attention
+ weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
+
+ Args:
+ embed_dim: the input and output embedding dimension
+ num_heads: the number of attention heads
+ value_head_dim: the value dimension per head
+ """
+
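+    # Rough usage sketch (illustrative only; the configuration values below are
+    # hypothetical, not taken from any particular recipe):
+    #   attn_weights_mod = RelPositionMultiheadAttentionWeights(...)
+    #   self_attn = SelfAttention(embed_dim=384, num_heads=8, value_head_dim=12)
+    #   attn_weights = attn_weights_mod(x, pos_emb)  # (num_heads, batch_size, seq_len, seq_len)
+    #   y = self_attn(x, attn_weights)               # (seq_len, batch_size, embed_dim)
+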
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ value_head_dim: int,
+ ) -> None:
+ super().__init__()
+ self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
+
+ self.out_proj = ScaledLinear(
+ num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """
+ Args:
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+            with the last two dims interpreted as (tgt_seq_len, src_seq_len). Expect
+ attn_weights.sum(dim=-1) == 1.
+ Returns:
+ a tensor with the same shape as x.
+ """
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
+ value_head_dim = x.shape[-1]
+
+ # todo: see whether there is benefit in overriding matmul
+ x = torch.matmul(attn_weights, x)
+        # now x: (num_heads, batch_size, seq_len, value_head_dim)
+
+ x = (
+ x.permute(2, 1, 0, 3)
+ .contiguous()
+ .view(seq_len, batch_size, num_heads * value_head_dim)
+ )
+
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+ x = self.out_proj(x)
+ x = self.whiten(x)
+
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ cached_val: Tensor,
+ left_context_len: int,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
+          attn_weights: a tensor of shape (num_heads, batch_size, seq_len, left_context_len + seq_len),
+            with the last two dims interpreted as (tgt_seq_len, src_seq_len). Expect
+ attn_weights.sum(dim=-1) == 1.
+ cached_val: cached attention value tensor of left context,
+ of shape (left_context_len, batch_size, value_dim)
+ left_context_len: number of left context frames.
+
+ Returns:
+ - attention weighted output, a tensor with the same shape as x.
+ - updated cached attention value tensor of left context.
+ """
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ seq_len2 = seq_len + left_context_len
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)
+
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
+
+ # Pad cached left contexts
+ assert cached_val.shape[0] == left_context_len, (
+ cached_val.shape[0],
+ left_context_len,
+ )
+ x = torch.cat([cached_val, x], dim=0)
+ # Update cached left contexts
+ cached_val = x[-left_context_len:, ...]
+
+ x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
+ value_head_dim = x.shape[-1]
+
+ # todo: see whether there is benefit in overriding matmul
+ x = torch.matmul(attn_weights, x)
+        # now x: (num_heads, batch_size, seq_len, value_head_dim)
+
+ x = (
+ x.permute(2, 1, 0, 3)
+ .contiguous()
+ .view(seq_len, batch_size, num_heads * value_head_dim)
+ )
+
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+ x = self.out_proj(x)
+
+ return x, cached_val
+
+
+class FeedforwardModule(nn.Module):
+ """Feedforward module in Zipformer2 model."""
+
+ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
+ super(FeedforwardModule, self).__init__()
+ self.in_proj = nn.Linear(embed_dim, feedforward_dim)
+
+ self.hidden_balancer = Balancer(
+ feedforward_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=1.0,
+ min_abs=0.75,
+ max_abs=5.0,
+ )
+
+ # shared_dim=0 means we share the dropout mask along the time axis
+ self.out_proj = ActivationDropoutAndLinear(
+ feedforward_dim,
+ embed_dim,
+ activation="SwooshL",
+ dropout_p=dropout,
+ dropout_shared_dim=0,
+ bias=True,
+ initial_scale=0.1,
+ )
+
+ self.out_whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(self, x: Tensor):
+ x = self.in_proj(x)
+ x = self.hidden_balancer(x)
+ # out_proj contains SwooshL activation, then dropout, then linear.
+ x = self.out_proj(x)
+ x = self.out_whiten(x)
+ return x
+
+
+class NonlinAttention(nn.Module):
+ """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
+ from the attention module) in place of actual convolution. We also took out the second nonlinearity, the
+ one after the attention mechanism.
+
+ Args:
+        channels (int): The number of input (and output) channels.
+        hidden_channels (int): The number of hidden channels; in_proj outputs
+          3 * hidden_channels.
+ """
+
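+    # Rough data flow, as implemented in forward() below: in_proj splits its output into
+    # (s, x, y); s gates x through a tanh, the attention weights mix x across time
+    # (taking the place of the convolution), y gates the result, and out_proj maps back
+    # to `channels`.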
+ def __init__(
+ self,
+ channels: int,
+ hidden_channels: int,
+ ) -> None:
+ super().__init__()
+
+ self.hidden_channels = hidden_channels
+
+ self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
+
+        # balancer that goes before the tanh. We give it a relatively large min_abs value
+        # because we noticed that well-trained instances of this module have large abs-values
+        # before the tanh, starting from about 3, while poorly-trained instances of the
+        # module have smaller abs values there.
+ self.balancer = Balancer(
+ hidden_channels,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
+ max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
+ min_abs=0.5,
+ max_abs=5.0,
+ )
+ self.tanh = nn.Tanh()
+
+ self.identity1 = Identity() # for diagnostics.
+ self.identity2 = Identity() # for diagnostics.
+ self.identity3 = Identity() # for diagnostics.
+
+ self.out_proj = ScaledLinear(
+ hidden_channels, channels, bias=True, initial_scale=0.05
+ )
+
+ self.whiten1 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.whiten2 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+        """
+ Args:
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+ Returns:
+ a Tensor with the same shape as x
+ """
+ x = self.in_proj(x)
+
+ (seq_len, batch_size, _) = x.shape
+ hidden_channels = self.hidden_channels
+
+ s, x, y = x.chunk(3, dim=2)
+
+ # s will go through tanh.
+
+ s = self.balancer(s)
+ s = self.tanh(s)
+
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+ x = self.whiten1(x)
+ x = x * s
+ x = self.identity1(x) # diagnostics only, it's the identity.
+
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = torch.matmul(attn_weights, x)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+ y = self.identity2(y)
+ x = x * y
+ x = self.identity3(x)
+
+ x = self.out_proj(x)
+ x = self.whiten2(x)
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ cached_x: Tensor,
+ left_context_len: int,
+ ) -> Tuple[Tensor, Tensor]:
+        """
+ Args:
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+ cached_x: left context, a Tensor of shape
+ (num_heads, batch_size, left_context_len, head_dim)
+ left_context_len: number of left context frames.
+ Returns:
+ - a Tensor with the same shape as x
+ - updated left context with same shape as cached_x
+ """
+ x = self.in_proj(x)
+
+ (seq_len, batch_size, _) = x.shape
+ hidden_channels = self.hidden_channels
+
+ s, x, y = x.chunk(3, dim=2)
+
+ # s will go through tanh.
+ s = self.tanh(s)
+
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+ x = x * s
+
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (
+ num_heads,
+ batch_size,
+ seq_len,
+ left_context_len + seq_len,
+ )
+
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+
+ # Pad cached tensor
+ assert cached_x.shape[2] == left_context_len, (
+ cached_x.shape[2],
+ left_context_len,
+ )
+ x_pad = torch.cat([cached_x, x], dim=2)
+ # Update cached tensor
+ cached_x = x_pad[:, :, -left_context_len:, :]
+
+ x = torch.matmul(attn_weights, x_pad)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+ x = x * y
+
+ x = self.out_proj(x)
+ return x, cached_x
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Zipformer2 model.
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
+
+ Args:
+ channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+        causal (bool): Whether to use a chunk-causal depthwise convolution.
+
+ """
+
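+    # Overall structure, as implemented below: pointwise in_proj -> sigmoid gating
+    # (GLU-like) -> depthwise conv (chunk-causal when causal=True) -> balancer/whitener
+    # -> SwooshR activation + pointwise out_proj.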
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ causal: bool,
+ ) -> None:
+ """Construct a ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ bottleneck_dim = channels
+ self.causal = causal
+
+ self.in_proj = nn.Linear(
+ channels,
+ 2 * bottleneck_dim,
+ )
+ # the gradients on in_proj are a little noisy, likely to do with the
+ # sigmoid in glu.
+
+ # after in_proj we put x through a gated linear unit (nn.functional.glu).
+ # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+ # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+ # between 50 and 100 for different channels. This will cause very peaky and
+ # sparse derivatives for the sigmoid gating function, which will tend to make
+ # the loss function not learn effectively. (for most layers the average absolute values
+ # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+ # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
+ # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid.) The idea is that if we
+ # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+ # it will be in a better position to start learning something, i.e. to latch onto
+ # the correct range.
+ self.balancer1 = Balancer(
+ bottleneck_dim,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
+ max_positive=1.0,
+ min_abs=1.5,
+ max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
+ )
+
+ self.activation1 = Identity() # for diagnostics
+
+ self.sigmoid = nn.Sigmoid()
+
+ self.activation2 = Identity() # for diagnostics
+
+ assert kernel_size % 2 == 1
+
+ self.depthwise_conv = (
+ ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
+ if causal
+ else nn.Conv1d(
+ in_channels=bottleneck_dim,
+ out_channels=bottleneck_dim,
+ groups=bottleneck_dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+
+ self.balancer2 = Balancer(
+ bottleneck_dim,
+ channel_dim=1,
+ min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
+ max_positive=1.0,
+ min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
+ max_abs=10.0,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.out_proj = ActivationDropoutAndLinear(
+ bottleneck_dim,
+ channels,
+ activation="SwooshR",
+ dropout_p=0.0,
+ initial_scale=0.05,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ chunk_size: int = -1,
+ ) -> Tensor:
+ """Compute convolution module.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ src_key_padding_mask: the mask for the src keys per batch (optional):
+ (batch, #time), contains True in masked positions.
+
+ Returns:
+ Tensor: Output tensor (#time, batch, channels).
+
+ """
+
+ x = self.in_proj(x) # (time, batch, 2*channels)
+
+ x, s = x.chunk(2, dim=2)
+ s = self.balancer1(s)
+ s = self.sigmoid(s)
+ x = self.activation1(x) # identity.
+ x = x * s
+ x = self.activation2(x) # identity
+
+ # (time, batch, channels)
+
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ if src_key_padding_mask is not None:
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+ if (
+ not torch.jit.is_scripting()
+ and not torch.jit.is_tracing()
+ and chunk_size >= 0
+ ):
+            # Exporting a model for simulated streaming decoding is not supported
+ assert (
+ self.causal
+ ), "Must initialize model with causal=True if you use chunk_size"
+ x = self.depthwise_conv(x, chunk_size=chunk_size)
+ else:
+ x = self.depthwise_conv(x)
+
+ x = self.balancer2(x)
+ x = x.permute(2, 0, 1) # (time, batch, channels)
+
+ x = self.whiten(x) # (time, batch, channels)
+ x = self.out_proj(x) # (time, batch, channels)
+
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ cache: Tensor,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """Compute convolution module in streaming forward mode.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ cache: cached left context for depthwise_conv of shape
+ (#batch, channels, left_pad)
+ src_key_padding_mask: the mask for the src keys per batch (optional):
+ (batch, #time), contains True in masked positions.
+
+ Returns:
+ - Output tensor (#time, batch, channels).
+ - Updated cache (#batch, channels, left_pad)
+ """
+
+ x = self.in_proj(x) # (time, batch, 2*channels)
+
+ x, s = x.chunk(2, dim=2)
+ s = self.sigmoid(s)
+ x = x * s
+ # (time, batch, channels)
+
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ if src_key_padding_mask is not None:
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+ x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)
+
+ x = x.permute(2, 0, 1) # (time, batch, channels)
+
+ x = self.out_proj(x) # (time, batch, channels)
+
+ return x, cache
+
+
+class ScalarMultiply(nn.Module):
+ def __init__(self, scale: float):
+ super().__init__()
+ self.scale = scale
+
+ def forward(self, x):
+ return x * self.scale
+
+
+def _test_zipformer_main(causal: bool = False):
+ batch_size = 5
+ seq_len = 20
+ # Just make sure the forward pass runs.
+
+ c = Zipformer2(
+ encoder_dim=(64, 96),
+ encoder_unmasked_dim=(48, 64),
+ num_heads=(4, 4),
+ causal=causal,
+ chunk_size=(4,) if causal else (-1,),
+ left_context_frames=(64,),
+ )
+ f = c(
+ torch.randn(seq_len, batch_size, 64),
+ torch.full((batch_size,), seq_len, dtype=torch.int64),
+ )
+ f[0].sum().backward()
+ c.eval()
+ f = c(
+ torch.randn(seq_len, batch_size, 64),
+ torch.full((batch_size,), seq_len, dtype=torch.int64),
+ )
+ f # to remove flake8 warnings
+
+
+class AdapterModule(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int = 384,
+ bottleneck_dim: int = 16,
+ ):
+ # The simplest adapter
+ super(AdapterModule, self).__init__()
+ self.embed_dim = embed_dim
+ self.bottleneck_dim = bottleneck_dim
+ self.activation = SwooshL()
+
+ self.in_proj = nn.Linear(embed_dim, bottleneck_dim)
+ self.out_proj = nn.Linear(bottleneck_dim, embed_dim)
+
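+    # In effect this is a residual bottleneck: out = x + out_proj(SwooshL(in_proj(x))),
+    # with bottleneck_dim typically much smaller than embed_dim so that only a small
+    # number of adapter parameters need to be trained.
+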
+ def forward(self, x):
+ x_orig = x
+ x = self.activation(self.in_proj(x))
+ x = self.out_proj(x)
+ return x_orig + x
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.INFO)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ _test_zipformer_main(False)
+ _test_zipformer_main(True)
diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py
index 60990456d..60112a84e 100755
--- a/egs/librispeech/ASR/zipformer_ctc/train.py
+++ b/egs/librispeech/ASR/zipformer_ctc/train.py
@@ -62,6 +62,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -797,9 +798,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/zipformer_lora/asr_datamodule.py b/egs/librispeech/ASR/zipformer_lora/asr_datamodule.py
new file mode 120000
index 000000000..fa1b8cca3
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/asr_datamodule.py
@@ -0,0 +1 @@
+../tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/beam_search.py b/egs/librispeech/ASR/zipformer_lora/beam_search.py
new file mode 120000
index 000000000..8554e44cc
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
new file mode 100755
index 000000000..4d93a905f
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
@@ -0,0 +1,1115 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+conversational_filler = [
+ "UH",
+ "UHH",
+ "UM",
+ "EH",
+ "MM",
+ "HM",
+ "AH",
+ "HUH",
+ "HA",
+ "ER",
+ "OOF",
+ "HEE",
+ "ACH",
+ "EEE",
+ "EW",
+]
+unk_tags = ["<UNK>", "<unk>"]
+gigaspeech_punctuations = [
+    "<COMMA>",
+    "<PERIOD>",
+    "<QUESTIONMARK>",
+    "<EXCLAMATIONPOINT>",
+]
+gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
+non_scoring_words = (
+ conversational_filler
+ + unk_tags
+ + gigaspeech_punctuations
+ + gigaspeech_garbage_utterance_tags
+)
+
+
+def asr_text_post_processing(text: str) -> str:
+ # 1. convert to uppercase
+ text = text.upper()
+
+ # 2. remove hyphen
+ # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
+ text = text.replace("-", " ")
+
+ # 3. remove non-scoring words from evaluation
+ remaining_words = []
+ for word in text.split():
+ if word in non_scoring_words:
+ continue
+ remaining_words.append(word)
+
+ return " ".join(remaining_words)
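+
+# Illustrative example of the post-processing above (hypothetical input, not taken
+# from the recipe): "uh e-commerce is great" -> "E COMMERCE IS GREAT"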
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def post_processing(
+ results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+ new_results = []
+ for key, ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((key, new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
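+    # For example (illustrative only): with greedy_search and a batch of two utterances,
+    # the return value might look like {"greedy_search": [["HELLO", "WORLD"], ["FOO"]]}.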
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = post_processing(results)
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+ gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts()
+
+ dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts)
+ test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts)
+
+ test_sets = ["dev", "test"]
+    test_dls = [dev_dl, test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer_lora/decoder.py b/egs/librispeech/ASR/zipformer_lora/decoder.py
new file mode 120000
index 000000000..cab465d2b
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/decoder.py
@@ -0,0 +1 @@
+../zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/encoder_interface.py b/egs/librispeech/ASR/zipformer_lora/encoder_interface.py
new file mode 120000
index 000000000..aa5d0217a
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/encoder_interface.py
@@ -0,0 +1 @@
+../transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/export.py b/egs/librispeech/ASR/zipformer_lora/export.py
new file mode 100755
index 000000000..d47666bef
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/export.py
@@ -0,0 +1,543 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+Note: This is an example for the LibriSpeech dataset. If you are using a
+different dataset, you should change the argument values according to your dataset.
+
+(1) Export to torchscript model using torch.jit.script()
+
+- For non-streaming model:
+
+./zipformer_lora/export.py \
+ --exp-dir ./zipformer_lora/exp \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("jit_script.pt")`.
+
+Check ./jit_pretrained.py for its usage.
+
+Check https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+- For streaming model:
+
+./zipformer_lora/export.py \
+ --exp-dir ./zipformer_lora/exp \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
+You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
+
+Check ./jit_pretrained_streaming.py for its usage.
+
+Check https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+- For non-streaming model:
+
+./zipformer_lora/export.py \
+ --exp-dir ./zipformer_lora/exp \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9
+
+- For streaming model:
+
+./zipformer_lora/export.py \
+ --exp-dir ./zipformer_lora/exp \
+ --causal 1 \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+- For non-streaming model:
+
+To use the generated file with `zipformer_lora/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/librispeech/ASR
+ ./zipformer_lora/decode.py \
+ --exp-dir ./zipformer_lora/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bpe_500/bpe.model
+
+- For streaming model:
+
+To use the generated file with `zipformer_lora/decode.py` and `zipformer_lora/streaming_decode.py`, you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/librispeech/ASR
+
+ # simulated streaming decoding
+ ./zipformer_lora/decode.py \
+ --exp-dir ./zipformer_lora/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bpe_500/bpe.model
+
+ # chunk-wise streaming decoding
+ ./zipformer_lora/streaming_decode.py \
+ --exp-dir ./zipformer_lora/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+- non-streaming model:
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+
+- streaming model:
+https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
+
+with the following commands:
+
+ sudo apt-get install git-lfs
+ git lfs install
+ git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+ git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
+ # You will find the pre-trained models in exp dir
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import torch
+from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params
+from scaling_converter import convert_scaled_to_non_scaled
+from torch import Tensor, nn
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import make_pad_mask, num_tokens, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer_lora/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ It will generate a file named jit_script.pt.
+ Check ./jit_pretrained.py for how to use it.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+class EncoderModel(nn.Module):
+ """A wrapper for encoder and encoder_embed"""
+
+ def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+ super().__init__()
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+
+ def forward(
+ self, features: Tensor, feature_lengths: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ features: (N, T, C)
+ feature_lengths: (N,)
+ """
+ x, x_lens = self.encoder_embed(features, feature_lengths)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ return encoder_out, encoder_out_lens
+
+
+class StreamingEncoderModel(nn.Module):
+ """A wrapper for encoder and encoder_embed"""
+
+ def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+ super().__init__()
+ assert len(encoder.chunk_size) == 1, encoder.chunk_size
+ assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
+ self.chunk_size = encoder.chunk_size[0]
+ self.left_context_len = encoder.left_context_frames[0]
+
+        # The encoder_embed subsamples features: T -> (T - 7) // 2
+        # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+ self.pad_length = 7 + 2 * 3
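+        # In other words, pad_length = 7 + 2 * 3 = 13 input frames: 7 for the
+        # subsampling context (T -> (T - 7) // 2) plus 2 * 3 frames so that the
+        # ConvNeXt module still has its 3 frames of right context after
+        # subsampling (a restatement of the two comments above, for clarity).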
+
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+
+ def forward(
+ self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
+ ) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """Streaming forward for encoder_embed and encoder.
+
+ Args:
+ features: (N, T, C)
+ feature_lengths: (N,)
+ states: a list of Tensors
+
+ Returns encoder outputs, output lengths, and updated states.
+ """
+ chunk_size = self.chunk_size
+ left_context_len = self.left_context_len
+
+ cached_embed_left_pad = states[-2]
+ x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
+ x=features,
+ x_lens=feature_lengths,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+
+ # processed_mask is used to mask out initial states
+ processed_mask = torch.arange(left_context_len, device=x.device).expand(
+ x.size(0), left_context_len
+ )
+ processed_lens = states[-1] # (batch,)
+ # (batch, left_context_size)
+ processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ encoder_states = states[:-2]
+
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = self.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+ return encoder_out, encoder_out_lens, new_states
+
+ @torch.jit.export
+ def get_init_states(
+ self,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+ ) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+ """
+ states = self.encoder.get_init_states(batch_size, device)
+
+ embed_states = self.encoder_embed.get_init_states(batch_size, device)
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ states.append(processed_lens)
+
+ return states
+
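+# A minimal streaming-usage sketch of the wrapper above (illustrative only; the
+# tensor names are hypothetical and not part of this script):
+#
+#   wrapped = StreamingEncoderModel(model.encoder, model.encoder_embed)
+#   states = wrapped.get_init_states(batch_size=1)
+#   # states[i * 6 : (i + 1) * 6] are the six caches of encoder layer i;
+#   # states[-2] is the encoder_embed left padding, states[-1] is processed_lens.
+#   encoder_out, encoder_out_lens, states = wrapped(chunk_feats, chunk_lens, states)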
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ # if torch.cuda.is_available():
+ # device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ # merge the LoRA weights
+ model.eval()
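+    # Calling eval() here is what is expected to fold the LoRA low-rank updates
+    # into the base weights of the LoRA-enabled layers (assuming the layers in
+    # scaling.py merge on eval(), as common LoRA implementations do), so the
+    # state_dict copied below already contains the merged parameters.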
+
+ params.use_lora = False
+ base_model = get_model(params)
+
+ new_state_dict = {}
+ state_dict = model.state_dict()
+ param_names = base_model.state_dict().keys()
+ for k in param_names:
+ assert k in state_dict.keys()
+ new_state_dict[k] = state_dict[k]
+
+ base_model.load_state_dict(new_state_dict, strict=True)
+
+ model = base_model
+ model.eval()
+
+ if params.jit is True:
+ convert_scaled_to_non_scaled(model, inplace=True)
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+
+ # Wrap encoder and encoder_embed as a module
+ if params.causal:
+ model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
+ chunk_size = model.encoder.chunk_size
+ left_context_len = model.encoder.left_context_len
+ filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
+ else:
+ model.encoder = EncoderModel(model.encoder, model.encoder_embed)
+ filename = "jit_script.pt"
+
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ model.save(str(params.exp_dir / filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torchscript. Export model.state_dict()")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py
new file mode 100755
index 000000000..0464cf65c
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/finetune.py
@@ -0,0 +1,1553 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# Fine-tune without mux (i.e., not mixing with the original training data):
+./zipformer_lora/finetune.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --do-finetune 1 \
+ --finetune-ckpt path/to/ckpt \
+ --base-lr 0.0045 \
+ --use-mux 0 \
+  --exp-dir zipformer_lora/exp_finetune \
+ --max-duration 1000
+
+# Fine-tune with mux (i.e., mixing with the original training data):
+./zipformer_lora/finetune.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --do-finetune 1 \
+ --finetune-ckpt path/to/ckpt \
+ --base-lr 0.0045 \
+ --use-mux 1 \
+  --exp-dir zipformer_lora/exp_finetune \
+ --max-duration 1000
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut, CutSet
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+    # Returns the number of batches we would have used so far if we had used the
+    # reference duration. This is for the purposes of set_batch_count().
+    # Note that we add a very large constant here so that the ScheduledFloat
+    # variables take on their end values.
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ ) + 100000
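+# A worked example of the formula above (illustrative numbers only): with
+# --max-duration 1000, --world-size 4 and the default --ref-duration 600,
+# batch_idx_train = 300 gives 300 * (1000 * 4) / 600 + 100000 = 102000.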
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--do-finetune",
+ type=str2bool,
+ default=True,
+ help="If true, finetune from a pre-trained checkpoint",
+ )
+
+ parser.add_argument(
+ "--use-mux",
+ type=str2bool,
+ default=False,
+ help="""
+        Whether to mix the fine-tuning data with the original training data.
+        If true, the new data is muxed with the original training data during
+        fine-tuning (see the weights passed to CutSet.mux in run()). This is
+        useful if you want to maintain the performance on the original domain.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-lora", type=str2bool, default=True, help="If use LoRA for fine-tune"
+ )
+
+ parser.add_argument(
+ "--lora-r", type=int, default=0, help="The bottleneck dimension of LoRA"
+ )
+
+ parser.add_argument(
+ "--init-modules",
+ type=str,
+ default=None,
+ help="""
+        Modules to be initialized. It matches all parameters starting with
+        a specific key. The keys are comma-separated. If None,
+        all modules will be initialised. For example, if you only want to
+        initialise all parameters starting with "encoder", use "encoder";
+        if you want to initialise parameters starting with encoder or joiner,
+        use "encoder,joiner".
+ """,
+ )
+
+ parser.add_argument(
+ "--finetune-ckpt",
+ type=str,
+ default=None,
+ help="Fine-tuning from which checkpoint (path to a .pt file)",
+ )
+
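+# A hypothetical fine-tuning invocation exercising the flags above (the values
+# are illustrative, not tuned recommendations):
+#
+#   ./zipformer_lora/finetune.py \
+#     --do-finetune 1 \
+#     --use-lora 1 \
+#     --lora-r 8 \
+#     --finetune-ckpt path/to/pretrained.pt \
+#     --exp-dir zipformer_lora/exp_finetune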
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr",
+ type=float,
+ default=0.045,
+ help="""The base learning rate.
+ It is set to a very small value as we are doing fine-tuning""",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=100000.0,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. It is set to a very large value here to prevent the lr from decaying too fast
+ during fine-tuning.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=100.0,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ It is set to a very large value here to prevent the lr from decaying too fast
+ during fine-tuning.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+                       epochs.
+
+    - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
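+# For instance, _to_int_tuple("2,2,3,4,3,2") returns (2, 2, 3, 4, 3, 2).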
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
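+# As a rough worked example of the comment above: an utterance with T = 1000
+# feature frames (about 10 seconds at 100 frames per second) is embedded to
+# (1000 - 7) // 2 = 496 frames before the extra output downsampling inside
+# Zipformer2.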
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ use_lora=params.use_lora,
+ lora_r=params.lora_r if params.use_lora else 0,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def load_model_params(
+ ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
+):
+ """Load model params from checkpoint
+
+ Args:
+ ckpt (str): Path to the checkpoint
+ model (nn.Module): model to be loaded
+ init_modules (list[str]): List of modules to be initialized
+
+ """
+ logging.info(f"Loading checkpoint from {ckpt}")
+ checkpoint = torch.load(ckpt, map_location="cpu")
+
+ # if module list is empty, load the whole model from ckpt
+ if not init_modules:
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ model.load_state_dict(dst_state_dict, strict=strict)
+ else:
+ model.load_state_dict(checkpoint["model"], strict=strict)
+ else:
+ src_state_dict = checkpoint["model"]
+ dst_state_dict = model.state_dict()
+ for module in init_modules:
+ logging.info(f"Loading parameters starting with prefix {module}")
+ src_keys = [
+ k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ dst_keys = [
+ k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ assert set(src_keys) == set(dst_keys) # two sets should match exactly
+ for key in src_keys:
+ dst_state_dict[key] = src_state_dict.pop(key)
+
+ model.load_state_dict(dst_state_dict, strict=strict)
+
+ return None
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+      The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+            # take down the scale on the simple loss from 1.0 at the start
+            # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+    valid_dls: List[torch.utils.data.DataLoader],
+ valid_sets: List[str],
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dls:
+        Dataloaders for the validation datasets.
+      valid_sets:
+        Names of the validation datasets, parallel to `valid_dls`.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ for name, m in model.named_modules():
+ if "lora" in name:
+ m.training = True
+ else:
+ m.training = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ for valid_set, valid_dl in zip(valid_sets, valid_dls):
+ logging.info(f"Computing validation loss on {valid_set}")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ logging.info(
+ f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
+ )
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
+ )
+ model.train()
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ # load model parameters for model fine-tuning
+ if params.do_finetune:
+ assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
+ modules = params.init_modules.split(",") if params.init_modules else None
+ checkpoints = load_model_params(
+ ckpt=params.finetune_ckpt, model=model, init_modules=modules, strict=False
+ )
+        # Need to update model_avg when using initialisation
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+ else:
+ # resuming training
+ assert params.start_epoch > 1, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ # keep the original model untouched, only update the adapters
+ num_trainable = 0
+ for name, p in model.named_parameters():
+ if "lora_A" in name or "lora_B" in name:
+ p.requires_grad = True
+ num_trainable += p.numel()
+ else:
+ p.requires_grad = False
+
+ logging.info(
+ "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
+ num_trainable, num_trainable / num_param * 100
+ )
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts()
+ if params.use_mux:
+ librispeech_cuts = librispeech.train_all_shuf_cuts()
+ train_cuts = CutSet.mux(
+ gigaspeech_cuts, # num cuts = 688182
+ librispeech_cuts, # num cuts = 843723
+ weights=[688182, 843723],
+ stop_early=True,
+ )
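+        # CutSet.mux samples from the two sources with probability proportional
+        # to the weights above (here the raw cut counts, i.e. roughly 45%
+        # GigaSpeech vs 55% LibriSpeech); with stop_early=True it stops as soon
+        # as one source is exhausted. (Our reading of lhotse's mux behaviour,
+        # not something stated elsewhere in this recipe.)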
+ else:
+ train_cuts = gigaspeech_cuts
+ logging.info(train_cuts)
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
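+        # For example, a 10-second cut at 100 feature frames per second has
+        # c.num_frames = 1000, so T = ((1000 - 7) // 2 + 1) // 2 = 248.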
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+
+ valid_sets = ["librispeech", "gigaspeech"]
+ valid_dls = [
+ librispeech.valid_dataloaders(valid_cuts),
+ librispeech.valid_dataloaders(gigaspeech_dev_cuts),
+ ]
+
+ # if not params.print_diagnostics:
+ # scan_pessimistic_batches_for_oom(
+ # model=model,
+ # train_dl=train_dl,
+ # optimizer=optimizer,
+ # sp=sp,
+ # params=params,
+ # )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dls=valid_dls,
+ valid_sets=valid_sets,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer_lora/joiner.py b/egs/librispeech/ASR/zipformer_lora/joiner.py
new file mode 120000
index 000000000..444cb5f15
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/joiner.py
@@ -0,0 +1 @@
+../zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/model.py b/egs/librispeech/ASR/zipformer_lora/model.py
new file mode 120000
index 000000000..0c6fe6112
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/model.py
@@ -0,0 +1 @@
+../zipformer/model.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/optim.py b/egs/librispeech/ASR/zipformer_lora/optim.py
new file mode 120000
index 000000000..207eecfcd
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/optim.py
@@ -0,0 +1 @@
+../zipformer/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py
new file mode 100644
index 000000000..3149db9f3
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/scaling.py
@@ -0,0 +1,2052 @@
+# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import math
+import random
+from typing import Optional, Tuple, Union
+
+import k2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+
+def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
+ max_value = torch.max(x, y)
+ diff = torch.abs(x - y)
+ return max_value + torch.log1p(torch.exp(-diff))
+
+
+# RuntimeError: Exporting the operator logaddexp to ONNX opset version
+# 14 is not supported. Please feel free to request support or submit
+# a pull request on PyTorch GitHub.
+#
+# The following function is to solve the above error when exporting
+# models to ONNX via torch.jit.trace()
+def logaddexp(x: Tensor, y: Tensor) -> Tensor:
+ # Caution(fangjun): Put torch.jit.is_scripting() before
+ # torch.onnx.is_in_onnx_export();
+ # otherwise, it will cause errors for torch.jit.script().
+ #
+ # torch.logaddexp() works for both torch.jit.script() and
+ # torch.jit.trace() but it causes errors for ONNX export.
+ #
+ if torch.jit.is_scripting():
+ # Note: We cannot use torch.jit.is_tracing() here as it also
+ # matches torch.onnx.export().
+ return torch.logaddexp(x, y)
+ elif torch.onnx.is_in_onnx_export():
+ return logaddexp_onnx(x, y)
+ else:
+ # for torch.jit.trace()
+ return torch.logaddexp(x, y)
+
+
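+# A small illustrative sketch (not invoked anywhere in this recipe): the ONNX-friendly
+# logaddexp_onnx should agree numerically with torch.logaddexp; the tensor values
+# below are arbitrary.
+def _example_logaddexp_onnx():
+    x = torch.tensor([0.0, 1.0, -2.0])
+    y = torch.tensor([1.0, -1.0, 3.0])
+    assert torch.allclose(logaddexp_onnx(x, y), torch.logaddexp(x, y), atol=1.0e-6)
+
+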
+class PiecewiseLinear(object):
+ """
+ Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
+ the x values in order. x values <[initial x] or >[final x] are mapped to [initial y], [final y]
+ respectively.
+ """
+
+ def __init__(self, *args):
+ assert len(args) >= 1, len(args)
+ if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
+ self.pairs = list(args[0].pairs)
+ else:
+ self.pairs = [(float(x), float(y)) for x, y in args]
+ for x, y in self.pairs:
+ assert isinstance(x, (float, int)), type(x)
+ assert isinstance(y, (float, int)), type(y)
+
+ for i in range(len(self.pairs) - 1):
+ assert self.pairs[i + 1][0] > self.pairs[i][0], (
+ i,
+ self.pairs[i],
+ self.pairs[i + 1],
+ )
+
+ def __str__(self):
+ # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
+ return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
+
+ def __call__(self, x):
+ if x <= self.pairs[0][0]:
+ return self.pairs[0][1]
+ elif x >= self.pairs[-1][0]:
+ return self.pairs[-1][1]
+ else:
+ cur_x, cur_y = self.pairs[0]
+ for i in range(1, len(self.pairs)):
+ next_x, next_y = self.pairs[i]
+ if x >= cur_x and x <= next_x:
+ return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
+ cur_x, cur_y = next_x, next_y
+ assert False
+
+ def __mul__(self, alpha):
+ return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
+
+ def __add__(self, x):
+ if isinstance(x, (float, int)):
+ return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
+ s, x = self.get_common_basis(x)
+ return PiecewiseLinear(
+ *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def max(self, x):
+ if isinstance(x, (float, int)):
+ x = PiecewiseLinear((0, x))
+ s, x = self.get_common_basis(x, include_crossings=True)
+ return PiecewiseLinear(
+ *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def min(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ x = PiecewiseLinear((0, x))
+ s, x = self.get_common_basis(x, include_crossings=True)
+ return PiecewiseLinear(
+ *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def __eq__(self, other):
+ return self.pairs == other.pairs
+
+ def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
+ """
+ Returns (self_mod, p_mod) which are equivalent piecewise linear
+ functions to self and p, but with the same x values.
+
+ p: the other piecewise linear function
+ include_crossings: if true, include in the x values positions
+ where the functions indicated by this and p cross.
+ """
+ assert isinstance(p, PiecewiseLinear), type(p)
+
+ # get sorted x-values without repetition.
+ x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
+ y_vals1 = [self(x) for x in x_vals]
+ y_vals2 = [p(x) for x in x_vals]
+
+ if include_crossings:
+ extra_x_vals = []
+ for i in range(len(x_vals) - 1):
+ if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
+ # if the two lines in this subsegment potentially cross each other..
+ diff_cur = abs(y_vals1[i] - y_vals2[i])
+ diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
+ # `pos`, between 0 and 1, gives the relative x position,
+ # with 0 being x_vals[i] and 1 being x_vals[i+1].
+ pos = diff_cur / (diff_cur + diff_next)
+ extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
+ extra_x_vals.append(extra_x_val)
+ if len(extra_x_vals) > 0:
+ x_vals = sorted(set(x_vals + extra_x_vals))
+ y_vals1 = [self(x) for x in x_vals]
+ y_vals2 = [p(x) for x in x_vals]
+ return (
+ PiecewiseLinear(*zip(x_vals, y_vals1)),
+ PiecewiseLinear(*zip(x_vals, y_vals2)),
+ )
+
+
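+# A small illustrative sketch (not invoked anywhere in this recipe): values outside
+# the given x range clamp to the first/last y, and values inside are linearly
+# interpolated.  The schedule below is arbitrary.
+def _example_piecewise_linear():
+    schedule = PiecewiseLinear((0.0, 0.2), (4000.0, 0.0))
+    assert schedule(-100.0) == 0.2
+    assert schedule(8000.0) == 0.0
+    assert abs(schedule(2000.0) - 0.1) < 1.0e-6
+
+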
+class ScheduledFloat(torch.nn.Module):
+ """
+ This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+ it does not have a working forward() function. You are supposed to cast it to float, as
+ in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+ It is a floating point value whose value changes depending on the batch count of the
+ training loop. It is a piecewise linear function where you specify the (x,y) pairs
+ in sorted order on x; x corresponds to the batch index. For batch-index values before the
+ first x or after the last x, we just use the first or last y value.
+
+ Example:
+ self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+ `default` is used when self.batch_count is not set or not in training mode or in
+ torch.jit scripting mode.
+ """
+
+ def __init__(self, *args, default: float = 0.0):
+ super().__init__()
+ # self.batch_count and self.name will be written to in the training loop.
+ self.batch_count = None
+ self.name = None
+ self.default = default
+ self.schedule = PiecewiseLinear(*args)
+
+ def extra_repr(self) -> str:
+ return (
+ f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
+ )
+
+ def __float__(self):
+ batch_count = self.batch_count
+ if (
+ batch_count is None
+ or not self.training
+ or torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ ):
+ return float(self.default)
+ else:
+ ans = self.schedule(self.batch_count)
+ if random.random() < 0.0002:
+ logging.info(
+ f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}"
+ )
+ return ans
+
+ def __add__(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ return ScheduledFloat(self.schedule + x, default=self.default)
+ else:
+ return ScheduledFloat(
+ self.schedule + x.schedule, default=self.default + x.default
+ )
+
+ def max(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ return ScheduledFloat(self.schedule.max(x), default=self.default)
+ else:
+ return ScheduledFloat(
+ self.schedule.max(x.schedule), default=max(self.default, x.default)
+ )
+
+
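+# A small illustrative sketch (not invoked anywhere in this recipe): the value returned
+# by float() follows the schedule only in training mode and once the training loop has
+# written batch_count; otherwise the default is used.  The schedule below is arbitrary.
+def _example_scheduled_float():
+    dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1), default=0.0)
+    dropout.eval()
+    assert float(dropout) == 0.0  # default, since not in training mode
+    dropout.train()
+    dropout.batch_count = 10000  # normally written by the training loop
+    assert abs(float(dropout) - 0.2) < 1.0e-6
+
+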
+FloatLike = Union[float, ScheduledFloat]
+
+
+def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
+ """
+ A randomized way of casting a floating point value to half precision.
+ """
+ if x.dtype == torch.float16:
+ return x
+ x_abs = x.abs()
+ is_too_small = x_abs < min_abs
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
+ # for those elements].
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
+
+
+class CutoffEstimator:
+ """
+ Estimates cutoffs of an arbitrary numerical quantity such that a specified
+ proportion of items will be above the cutoff on average.
+
+ p is the proportion of items that should be above the cutoff.
+ """
+
+ def __init__(self, p: float):
+ self.p = p
+ # total count of items
+ self.count = 0
+ # total count of items that were above the cutoff
+ self.count_above = 0
+ # initial cutoff value
+ self.cutoff = 0
+
+ def __call__(self, x: float) -> bool:
+ """
+ Returns true if x is above the cutoff.
+ """
+ ans = x > self.cutoff
+ self.count += 1
+ if ans:
+ self.count_above += 1
+ cur_p = self.count_above / self.count
+ delta_p = cur_p - self.p
+ if (delta_p > 0) == ans:
+ q = abs(delta_p)
+ self.cutoff = x * q + self.cutoff * (1 - q)
+ return ans
+
+
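+# A small illustrative sketch (not invoked anywhere in this recipe): with p=0.5 the
+# estimated cutoff drifts toward the median of the stream, so roughly half of the
+# samples are reported as being above it.
+def _example_cutoff_estimator():
+    est = CutoffEstimator(p=0.5)
+    num_above = sum(est(random.random()) for _ in range(10000))
+    assert 3000 < num_above < 7000
+
+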
+class SoftmaxFunction(torch.autograd.Function):
+ """
+ Tries to handle half-precision derivatives in a randomized way that should
+ be more accurate for training than the default behavior.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor, dim: int):
+ ans = x.softmax(dim=dim)
+ # if x dtype is float16, x.softmax() returns a float32 because
+ # (presumably) that op does not support float16, and autocast
+ # is enabled.
+ if torch.is_autocast_enabled():
+ ans = ans.to(torch.float16)
+ ctx.save_for_backward(ans)
+ ctx.x_dtype = x.dtype
+ ctx.dim = dim
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ (ans,) = ctx.saved_tensors
+ with torch.cuda.amp.autocast(enabled=False):
+ ans_grad = ans_grad.to(torch.float32)
+ ans = ans.to(torch.float32)
+ x_grad = ans_grad * ans
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+ return x_grad, None
+
+
+def softmax(x: Tensor, dim: int):
+ if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x.softmax(dim=dim)
+
+ return SoftmaxFunction.apply(x, dim)
+
+
+class MaxEigLimiterFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ coeffs: Tensor,
+ direction: Tensor,
+ channel_dim: int,
+ grad_scale: float,
+ ) -> Tensor:
+ ctx.channel_dim = channel_dim
+ ctx.grad_scale = grad_scale
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad, *args):
+ with torch.enable_grad():
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
+ x_orig.requires_grad = True
+ num_channels = x_orig.shape[ctx.channel_dim]
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
+ new_direction.requires_grad = False
+ x = x - x.mean(dim=0)
+ x_var = (x**2).mean()
+ x_residual = x - coeffs * new_direction
+ x_residual_var = (x_residual**2).mean()
+ # `variance_proportion` is the proportion of the variance accounted for
+ # by the top eigen-direction. This is to be minimized.
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+ variance_proportion.backward()
+ x_orig_grad = x_orig.grad
+ x_extra_grad = (
+ x_orig.grad
+ * ctx.grad_scale
+ * x_grad.norm()
+ / (x_orig_grad.norm() + 1.0e-20)
+ )
+ return x_grad + x_extra_grad.detach(), None, None, None, None
+
+
+class BiasNormFunction(torch.autograd.Function):
+ # This computes:
+ # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
+ # return x * scales
+ # (after unsqueezing the bias), but it does it in a memory-efficient way so that
+ # it can just store the returned value (chances are, this will also be needed for
+ # some other reason, related to the next operation, so we can save memory).
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ bias: Tensor,
+ log_scale: Tensor,
+ channel_dim: int,
+ store_output_for_backprop: bool,
+ ) -> Tensor:
+ assert bias.ndim == 1
+ if channel_dim < 0:
+ channel_dim = channel_dim + x.ndim
+ ctx.store_output_for_backprop = store_output_for_backprop
+ ctx.channel_dim = channel_dim
+ for _ in range(channel_dim + 1, x.ndim):
+ bias = bias.unsqueeze(-1)
+ scales = (
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
+ ) * log_scale.exp()
+ ans = x * scales
+ ctx.save_for_backward(
+ ans.detach() if store_output_for_backprop else x,
+ scales.detach(),
+ bias.detach(),
+ log_scale.detach(),
+ )
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor) -> Tensor:
+ ans_or_x, scales, bias, log_scale = ctx.saved_tensors
+ if ctx.store_output_for_backprop:
+ x = ans_or_x / scales
+ else:
+ x = ans_or_x
+ x = x.detach()
+ x.requires_grad = True
+ bias.requires_grad = True
+ log_scale.requires_grad = True
+ with torch.enable_grad():
+ # recompute scales from x, bias and log_scale.
+ scales = (
+ torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
+ ) * log_scale.exp()
+ ans = x * scales
+ ans.backward(gradient=ans_grad)
+ return x.grad, bias.grad.flatten(), log_scale.grad, None, None
+
+
+class BiasNorm(torch.nn.Module):
+ """
+ This is intended to be a simpler, and hopefully cheaper, replacement for
+ LayerNorm. The observation this is based on, is that Transformer-type
+ networks, especially with pre-norm, sometimes seem to set one of the
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
+ the LayerNorm because the output magnitude is then not strongly dependent
+ on the other (useful) features. Presumably the weight and bias of the
+ LayerNorm are required to allow it to do this.
+
+ Instead, we give the BiasNorm a trainable bias that it can use when
+ computing the scale for normalization. We also give it a (scalar)
+ trainable scale on the output.
+
+
+ Args:
+ num_channels: the number of channels, e.g. 512.
+ channel_dim: the axis/dimension corresponding to the channel,
+ interpreted as an offset from the input's ndim if negative.
+ This is NOT the num_channels; it should typically be one of
+ {-2, -1, 0, 1, 2, 3}.
+ log_scale: the initial log-scale that we multiply the output by; this
+ is learnable.
+ log_scale_min: FloatLike, minimum allowed value of log_scale
+ log_scale_max: FloatLike, maximum allowed value of log_scale
+ store_output_for_backprop: only possibly affects memory use; recommend
+ to set to True if you think the output of this module is more likely
+ than the input of this module to be required to be stored for the
+ backprop.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int = -1, # CAUTION: see documentation.
+ log_scale: float = 1.0,
+ log_scale_min: float = -1.5,
+ log_scale_max: float = 1.5,
+ store_output_for_backprop: bool = False,
+ ) -> None:
+ super(BiasNorm, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.log_scale = nn.Parameter(torch.tensor(log_scale))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+
+ self.log_scale_min = log_scale_min
+ self.log_scale_max = log_scale_max
+
+ self.store_output_for_backprop = store_output_for_backprop
+
+ def forward(self, x: Tensor) -> Tensor:
+ assert x.shape[self.channel_dim] == self.num_channels
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ channel_dim = self.channel_dim
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ bias = self.bias
+ for _ in range(channel_dim + 1, x.ndim):
+ bias = bias.unsqueeze(-1)
+ scales = (
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
+ ) * self.log_scale.exp()
+ return x * scales
+
+ log_scale = limit_param_value(
+ self.log_scale,
+ min=float(self.log_scale_min),
+ max=float(self.log_scale_max),
+ training=self.training,
+ )
+
+ return BiasNormFunction.apply(
+ x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop
+ )
+
+
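+# A small illustrative sketch (not invoked anywhere in this recipe): at initialization
+# the bias is zero, so with log_scale=0.0 the output has unit RMS over the channel
+# dimension.  The tensor shape below is arbitrary.
+def _example_bias_norm():
+    norm = BiasNorm(num_channels=8, channel_dim=-1, log_scale=0.0)
+    norm.eval()
+    x = torch.randn(3, 10, 8)
+    y = norm(x)
+    rms = (y**2).mean(dim=-1).sqrt()
+    assert torch.allclose(rms, torch.ones_like(rms), atol=1.0e-4)
+
+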
+def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
+ """
+ Behaves like a constructor of a modified version of nn.Linear
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Linear accepts
+ e.g. in_features, out_features, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Linear(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
+
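+# A small illustrative sketch (not invoked anywhere in this recipe): initial_scale only
+# changes the initial parameter values; the returned module is still a plain nn.Linear.
+def _example_scaled_linear():
+    torch.manual_seed(0)
+    a = ScaledLinear(4, 4, initial_scale=0.25)
+    torch.manual_seed(0)
+    b = nn.Linear(4, 4)
+    assert torch.allclose(a.weight, 0.25 * b.weight)
+
+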
+class LoRALayer:
+ def __init__(
+ self,
+ r: int,
+ lora_alpha: int,
+ lora_dropout: float,
+ merge_weights: bool,
+ ):
+ self.r = r
+ self.lora_alpha = lora_alpha
+ # Optional dropout
+ if lora_dropout > 0.0:
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
+ else:
+ self.lora_dropout = lambda x: x
+ # Mark the weight as unmerged
+ self.merged = False
+ self.merge_weights = merge_weights
+
+
+class ScaledLinear_lora(nn.Linear, LoRALayer):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ fan_in_fan_out: bool = False,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.0,
+ initial_scale: float = 1.0,
+ merge_weights: bool = True,
+ **kwargs,
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(
+ self,
+ r=r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ merge_weights=merge_weights,
+ )
+
+ self.initial_scale = initial_scale
+ self.fan_in_fan_out = fan_in_fan_out
+ if r > 0:
+ self.lora_A = nn.Parameter(torch.full((r, in_features), 0.0))
+ self.lora_B = nn.Parameter(torch.full((out_features, r), 0.0))
+ self.scaling = self.lora_alpha / self.r
+ self.weight.requires_grad = False
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ # initialize the parameters
+ nn.Linear.reset_parameters(self)
+ if hasattr(self, "lora_A"):
+ initial_scale = self.initial_scale
+ with torch.no_grad():
+ self.weight[:] *= initial_scale
+ if self.bias is not None:
+ nn.init.uniform_(
+ self.bias, -0.1 * initial_scale, 0.1 * initial_scale
+ )
+ if hasattr(self, "lora_A"):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ # this is different than what is described in the paper but should not affect performance
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ def train(self, mode: bool = True):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+
+ nn.Linear.train(self, mode)
+ if mode:
+ # We don't want the weights to be merged in training mode
+ if self.merge_weights and self.merged:
+ if self.r > 0:
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = False
+ else:
+ # When evaluating the model, we merge the weights for simplicity
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+
+ if self.r > 0 and not self.merged:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ delta_result = (
+ self.lora_dropout(x)
+ @ self.lora_A.transpose(0, 1)
+ @ self.lora_B.transpose(0, 1)
+ )
+ return result + delta_result * self.scaling
+ else:
+ return F.linear(x, T(self.weight), bias=self.bias)
+
+
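+# A small illustrative sketch (not invoked anywhere in this recipe): lora_B starts at
+# zero, so the LoRA branch is initially a no-op, and eval() merges the (still zero)
+# low-rank update into the frozen base weight.  Shapes below are arbitrary.
+def _example_scaled_linear_lora():
+    layer = ScaledLinear_lora(in_features=16, out_features=8, r=4, lora_alpha=4)
+    x = torch.randn(2, 16)
+    y_before = layer(x)
+    layer.eval()  # merges lora_B @ lora_A into the base weight
+    y_after = layer(x)
+    assert torch.allclose(y_before, y_after, atol=1.0e-6)
+    # Only the LoRA matrices are trainable; the base weight is frozen.
+    assert not layer.weight.requires_grad and layer.lora_A.requires_grad
+
+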
+def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
+ """
+ Behaves like a constructor of a modified version of nn.Conv1d
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Conv1d accepts
+ e.g. in_channels, out_channels, kernel_size, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Conv1d(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
+
+def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d:
+ """
+ Behaves like a constructor of a modified version of nn.Conv2d
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Conv2d accepts
+ e.g. in_channels, out_channels, kernel_size, bias=False, but:
+ NO PADDING-RELATED ARGS.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Conv2d(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
+
+class ChunkCausalDepthwiseConv1d(torch.nn.Module):
+ """
+ Behaves like a depthwise 1d convolution, except that it is causal in
+ a chunkwise way, as if we had a block-triangular attention mask.
+ The chunk size is provided at test time (it should probably be
+ kept in sync with the attention mask).
+
+ This has a little more than twice the parameters of a conventional
+ depthwise conv1d module: we implement it by having one
+ depthwise convolution, of half the width, that is causal (via
+ right-padding); and one depthwise convolution that is applied only
+ within chunks, that we multiply by a scaling factor which depends
+ on the position within the chunk.
+
+ Args:
+ channels: the number of input and output channels (the convolution is depthwise).
+ kernel_size: the kernel size; must be an odd number.
+ initial_scale: you can override this if you want to increase or decrease
+ the initial magnitude of the module's output (it scales the initial
+ convolution weights and biases).
+ bias: if true, include a bias term in the chunkwise convolution.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ initial_scale: float = 1.0,
+ bias: bool = True,
+ ):
+ super().__init__()
+ assert kernel_size % 2 == 1
+
+ half_kernel_size = (kernel_size + 1) // 2
+ # will pad manually, on one side.
+ self.causal_conv = nn.Conv1d(
+ in_channels=channels,
+ out_channels=channels,
+ groups=channels,
+ kernel_size=half_kernel_size,
+ padding=0,
+ bias=True,
+ )
+
+ self.chunkwise_conv = nn.Conv1d(
+ in_channels=channels,
+ out_channels=channels,
+ groups=channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ bias=bias,
+ )
+
+ # first row is correction factors added to the scale near the left edge of the chunk,
+ # second row is correction factors added to the scale near the right edge of the chunk,
+ # both of these are added to a default scale of 1.0.
+ self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
+ self.kernel_size = kernel_size
+
+ with torch.no_grad():
+ self.causal_conv.weight[:] *= initial_scale
+ self.chunkwise_conv.weight[:] *= initial_scale
+ if bias:
+ torch.nn.init.uniform_(
+ self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
+ )
+
+ def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
+ """
+ Forward function. Args:
+ x: a Tensor of shape (batch_size, channels, seq_len)
+ chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
+ """
+ (batch_size, num_channels, seq_len) = x.shape
+
+ # half_kernel_size = (self.kernel_size + 1) // 2
+ # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
+ # in the causal conv. It's the amount by which we must pad on the left,
+ # to make the convolution causal.
+ left_pad = self.kernel_size // 2
+
+ if chunk_size < 0 or chunk_size > seq_len:
+ chunk_size = seq_len
+ right_pad = -seq_len % chunk_size
+
+ x = torch.nn.functional.pad(x, (left_pad, right_pad))
+
+ x_causal = self.causal_conv(x[..., : left_pad + seq_len])
+ assert x_causal.shape == (batch_size, num_channels, seq_len)
+
+ x_chunk = x[..., left_pad:]
+ num_chunks = x_chunk.shape[2] // chunk_size
+ x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
+ x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
+ batch_size * num_chunks, num_channels, chunk_size
+ )
+ x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
+
+ chunk_scale = self._get_chunk_scale(chunk_size)
+
+ x_chunk = x_chunk * chunk_scale
+ x_chunk = x_chunk.reshape(
+ batch_size, num_chunks, num_channels, chunk_size
+ ).permute(0, 2, 1, 3)
+ x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
+ ..., :seq_len
+ ]
+
+ return x_chunk + x_causal
+
+ def _get_chunk_scale(self, chunk_size: int):
+ """Returns tensor of shape (num_channels, chunk_size) that will be used to
+ scale the output of self.chunkwise_conv."""
+ left_edge = self.chunkwise_conv_scale[0]
+ right_edge = self.chunkwise_conv_scale[1]
+ if chunk_size < self.kernel_size:
+ left_edge = left_edge[:, :chunk_size]
+ right_edge = right_edge[:, -chunk_size:]
+ else:
+ t = chunk_size - self.kernel_size
+ channels = left_edge.shape[0]
+ pad = torch.zeros(
+ channels, t, device=left_edge.device, dtype=left_edge.dtype
+ )
+ left_edge = torch.cat((left_edge, pad), dim=-1)
+ right_edge = torch.cat((pad, right_edge), dim=-1)
+ return 1.0 + (left_edge + right_edge)
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ cache: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """Streaming Forward function.
+
+ Args:
+ x: a Tensor of shape (batch_size, channels, seq_len)
+ cache: cached left context of shape (batch_size, channels, left_pad)
+ """
+ (batch_size, num_channels, seq_len) = x.shape
+
+ # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
+ # in the causal conv. It's the amount by which we must pad on the left,
+ # to make the convolution causal.
+ left_pad = self.kernel_size // 2
+
+ # Pad cache
+ assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad)
+ x = torch.cat([cache, x], dim=2)
+ # Update cache
+ cache = x[..., -left_pad:]
+
+ x_causal = self.causal_conv(x)
+ assert x_causal.shape == (batch_size, num_channels, seq_len)
+
+ x_chunk = x[..., left_pad:]
+ x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
+
+ chunk_scale = self._get_chunk_scale(chunk_size=seq_len)
+ x_chunk = x_chunk * chunk_scale
+
+ return x_chunk + x_causal, cache
+
+
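+# A small illustrative sketch (not invoked anywhere in this recipe): the output has the
+# same shape as the input for any chunk_size, which does not need to divide seq_len.
+# Sizes below are arbitrary; kernel_size must be odd.
+def _example_chunk_causal_conv():
+    conv = ChunkCausalDepthwiseConv1d(channels=4, kernel_size=7)
+    x = torch.randn(2, 4, 20)
+    assert conv(x, chunk_size=-1).shape == x.shape
+    assert conv(x, chunk_size=6).shape == x.shape
+
+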
+class BalancerFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ min_mean: float,
+ max_mean: float,
+ min_rms: float,
+ max_rms: float,
+ grad_scale: float,
+ channel_dim: int,
+ ) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ ctx.channel_dim = channel_dim
+ ctx.save_for_backward(x)
+ ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+ (x,) = ctx.saved_tensors
+ (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
+
+ try:
+ with torch.enable_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ x = x.to(torch.float32)
+ x = x.detach()
+ x.requires_grad = True
+ mean_dims = [i for i in range(x.ndim) if i != channel_dim]
+ uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
+ mean = x.mean(dim=mean_dims, keepdim=True)
+ stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
+ rms = uncentered_var.clamp(min=1.0e-20).sqrt()
+
+ m = mean / stddev
+ # part of loss that relates to mean / stddev
+ m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
+
+ # put a much larger scale on the RMS-max-limit loss, so that if both it and the
+ # m_loss are violated we fix the RMS loss first.
+ rms_clamped = rms.clamp(min=min_rms, max=max_rms)
+ r_loss = (rms_clamped / rms).log().abs()
+
+ loss = m_loss + r_loss
+
+ loss.backward(gradient=torch.ones_like(loss))
+ loss_grad = x.grad
+ loss_grad_rms = (
+ (loss_grad**2)
+ .mean(dim=mean_dims, keepdim=True)
+ .sqrt()
+ .clamp(min=1.0e-20)
+ )
+
+ loss_grad = loss_grad * (grad_scale / loss_grad_rms)
+
+ x_grad_float = x_grad.to(torch.float32)
+ # scale each element of loss_grad by the absolute value of the corresponding
+ # element of x_grad, which we view as a noisy estimate of its magnitude for that
+ # (frame and dimension). later we can consider factored versions.
+ x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+ x_grad = x_grad_mod.to(x_grad.dtype)
+ except Exception as e:
+ logging.info(
+ f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."
+ )
+
+ return x_grad, None, None, None, None, None, None
+
+
+class Balancer(torch.nn.Module):
+ """
+ Modifies the backpropped derivatives of a function to try to encourage, for
+ each channel, that it is positive at least a proportion `min_positive` of the
+ time and at most a proportion `max_positive` of the time, and that its
+ average absolute value (after mean subtraction) stays between `min_abs` and
+ `max_abs`. It does this by adding, with probability `prob`, a small extra term
+ to the gradients in the backward pass (scaled by `grad_scale`); the forward
+ pass returns x unmodified.
+
+ Args:
+ num_channels: the number of channels
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+ min_positive: the minimum, per channel, of the proportion of the time
+ that (x > 0), below which we start to modify the derivatives.
+ max_positive: the maximum, per channel, of the proportion of the time
+ that (x > 0), above which we start to modify the derivatives.
+ scale_gain_factor: determines the 'gain' with which we increase the
+ change in gradient once the constraints on min_abs and max_abs
+ are violated.
+ min_abs: the minimum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ max_abs: the maximum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ prob: determines the minimum probability with which we modify the
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
+ on each forward(). This is done randomly to prevent all layers
+ from doing it at the same time.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int,
+ min_positive: FloatLike = 0.05,
+ max_positive: FloatLike = 0.95,
+ min_abs: FloatLike = 0.2,
+ max_abs: FloatLike = 100.0,
+ grad_scale: FloatLike = 0.04,
+ prob: Optional[FloatLike] = None,
+ ):
+ super().__init__()
+
+ if prob is None:
+ prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
+ self.prob = prob
+ # 5% of the time we will return and do nothing because memory usage is
+ # too high.
+ self.mem_cutoff = CutoffEstimator(0.05)
+
+ # actually self.num_channels is no longer needed except for an assertion.
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.min_positive = min_positive
+ self.max_positive = max_positive
+ self.min_abs = min_abs
+ self.max_abs = max_abs
+ self.grad_scale = grad_scale
+
+ def forward(self, x: Tensor) -> Tensor:
+ if (
+ torch.jit.is_scripting()
+ or not x.requires_grad
+ or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
+ ):
+ return _no_op(x)
+
+ prob = float(self.prob)
+ if random.random() < prob:
+ # The following inner-functions convert from the way we historically specified
+ # these limitations, as limits on the absolute value and the proportion of positive
+ # values, to limits on the RMS value and the (mean / stddev).
+ def _abs_to_rms(x):
+ # for normally distributed data, if the expected absolute value is x, the
+ # expected rms value will be sqrt(pi/2) * x.
+ return 1.25331413732 * x
+
+ def _proportion_positive_to_mean(x):
+ def _atanh(x):
+ eps = 1.0e-10
+ # eps is to prevent crashes if x is exactly 0 or 1.
+ # we'll just end up returning a fairly large value.
+ return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0
+
+ def _approx_inverse_erf(x):
+ # 1 / (sqrt(pi) * ln(2)),
+ # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
+ # this approximation is extremely crude and gets progressively worse for
+ # x very close to -1 or +1, but we mostly care about the "middle" region
+ # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
+ # and math.erf(0.0407316414078772) = 0.045935330944660666,
+ # which is pretty close to 0.05.
+ return 0.8139535143 * _atanh(x)
+
+ # first convert x from the range 0..1 to the range -1..1 which the error
+ # function returns
+ x = -1 + (2 * x)
+ return _approx_inverse_erf(x)
+
+ min_mean = _proportion_positive_to_mean(float(self.min_positive))
+ max_mean = _proportion_positive_to_mean(float(self.max_positive))
+ min_rms = _abs_to_rms(float(self.min_abs))
+ max_rms = _abs_to_rms(float(self.max_abs))
+ grad_scale = float(self.grad_scale)
+
+ assert x.shape[self.channel_dim] == self.num_channels
+
+ return BalancerFunction.apply(
+ x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
+ )
+ else:
+ return _no_op(x)
+
+
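+# A small illustrative sketch (not invoked anywhere in this recipe): in the forward
+# direction the Balancer is an identity function; it only adjusts gradients in backprop.
+def _example_balancer():
+    balancer = Balancer(num_channels=6, channel_dim=-1)
+    x = torch.randn(4, 6, requires_grad=True)
+    assert torch.allclose(balancer(x), x)
+
+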
+def penalize_abs_values_gt(
+ x: Tensor, limit: float, penalty: float, name: str = None
+) -> Tensor:
+ """
+ Returns x unmodified, but in backprop will put a penalty for the excess of
+ the absolute values of elements of x over the limit "limit". E.g. if
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
+
+ Caution: the value of this penalty will be affected by grad scaling used
+ in automatic mixed precision training. For the purposes we use this for,
+ that shouldn't really matter, or may even be helpful; we just use this
+ to disallow really implausible values of scores to be given to softmax.
+
+ The name is for randomly printed debug info.
+ """
+ x_sign = x.sign()
+ over_limit = (x.abs() - limit) > 0
+ # The following is a memory efficient way to penalize the absolute values of
+ # x that's over the limit. (The memory efficiency comes when you think
+ # about which items torch needs to cache for the autograd, and which ones it
+ # can throw away). The numerical value of aux_loss as computed here will
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
+ # limit).relu().
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
+ # note: we don't do sum() here on aux_loss, but it's as if we had done
+ # sum() due to how with_loss() works.
+ x = with_loss(x, aux_loss, name)
+ # you must use x for something, or this will be ineffective.
+ return x
+
+
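+# A small illustrative sketch (not invoked anywhere in this recipe): the forward value
+# is unchanged; elements with |x| > limit pick up an extra gradient of sign(x) * penalty
+# on top of whatever gradient they receive from the rest of the graph.
+def _example_penalize_abs_values_gt():
+    x = torch.tensor([0.5, -20.0, 3.0], requires_grad=True)
+    y = penalize_abs_values_gt(x, limit=10.0, penalty=1.0e-3)
+    assert torch.allclose(y, x)
+    y.sum().backward()
+    # only the element at -20.0 exceeds the limit; its gradient is 1 + sign(-20) * 1e-3
+    assert torch.allclose(x.grad, torch.tensor([1.0, 1.0 - 1.0e-3, 1.0]))
+
+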
+def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
+ if x.ndim == 2:
+ return x.diag()
+ else:
+ (batch, dim, dim) = x.shape
+ x = x.reshape(batch, dim * dim)
+ x = x[:, :: dim + 1]
+ assert x.shape == (batch, dim)
+ return x
+
+
+def _whitening_metric(x: Tensor, num_groups: int):
+ """
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
+ of the centered feature covariance are the same within each group's covariance matrix
+ and also between groups.
+ Args:
+ x: a Tensor of shape (*, num_channels)
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
+ Returns:
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
+ greater than 1.0 otherwise.
+ """
+ assert x.dtype != torch.float16
+ x = x.reshape(-1, x.shape[-1])
+ (num_frames, num_channels) = x.shape
+ assert num_channels % num_groups == 0
+ channels_per_group = num_channels // num_groups
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
+ # x now has shape (num_groups, num_frames, channels_per_group)
+ # subtract the mean so we use the centered, not uncentered, covariance.
+ # My experience has been that when we "mess with the gradients" like this,
+ # it's better not to do anything that tries to move the mean around, because
+ # that can easily cause instability.
+ x = x - x.mean(dim=1, keepdim=True)
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
+ x_covar = torch.matmul(x.transpose(1, 2), x)
+ x_covar_mean_diag = _diag(x_covar).mean()
+ # the following expression is what we'd get if we took the matrix product
+ # of each covariance and measured the mean of its trace, i.e.
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
+ x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
+ return metric
+
+
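+# A small illustrative sketch (not invoked anywhere in this recipe): i.i.d. Gaussian
+# features are already (approximately) white, giving a metric near 1.0, while features
+# dominated by a single direction give a much larger value.
+def _example_whitening_metric():
+    white = torch.randn(10000, 8)
+    assert abs(_whitening_metric(white, num_groups=1).item() - 1.0) < 0.1
+    collapsed = torch.randn(10000, 1).expand(10000, 8).contiguous()
+    assert _whitening_metric(collapsed, num_groups=1).item() > 5.0
+
+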
+class WhiteningPenaltyFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
+ ctx.save_for_backward(x)
+ ctx.module = module
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor):
+ (x_orig,) = ctx.saved_tensors
+ w = ctx.module
+
+ try:
+ with torch.enable_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ x_detached = x_orig.to(torch.float32).detach()
+ x_detached.requires_grad = True
+
+ metric = _whitening_metric(x_detached, w.num_groups)
+
+ if random.random() < 0.005 or __name__ == "__main__":
+ logging.info(
+ f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
+ f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}"
+ )
+
+ if metric < float(w.whitening_limit):
+ w.prob = w.min_prob
+ return x_grad, None
+ else:
+ w.prob = w.max_prob
+ metric.backward()
+ penalty_grad = x_detached.grad
+ scale = w.grad_scale * (
+ x_grad.to(torch.float32).norm()
+ / (penalty_grad.norm() + 1.0e-20)
+ )
+ penalty_grad = penalty_grad * scale
+ return x_grad + penalty_grad.to(x_grad.dtype), None
+ except Exception as e:
+ logging.info(
+ f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue."
+ )
+ return x_grad, None
+
+
+class Whiten(nn.Module):
+ def __init__(
+ self,
+ num_groups: int,
+ whitening_limit: FloatLike,
+ prob: Union[float, Tuple[float, float]],
+ grad_scale: FloatLike,
+ ):
+ """
+ Args:
+ num_groups: the number of groups to divide the channel dim into before
+ whitening. We will attempt to make the feature covariance
+ within each group, after mean subtraction, as "white" as possible,
+ while having the same trace across all groups.
+ whitening_limit: a value greater than 1.0, that dictates how much
+ freedom we have to violate the constraints. 1.0 would mean perfectly
+ white, with exactly the same trace across groups; larger values
+ give more freedom. E.g. 2.0.
+ prob: the probability with which we apply the gradient modification
+ (also affects the grad scale). May be supplied as a float,
+ or as a pair (min_prob, max_prob)
+
+ grad_scale: determines the scale on the gradient term from this object,
+ relative to the rest of the gradient on the attention weights.
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
+ """
+ super(Whiten, self).__init__()
+ assert num_groups >= 1
+ assert float(whitening_limit) >= 1
+ assert grad_scale >= 0
+ self.num_groups = num_groups
+ self.whitening_limit = whitening_limit
+ self.grad_scale = grad_scale
+
+ if isinstance(prob, float):
+ prob = (prob, prob)
+ (self.min_prob, self.max_prob) = prob
+ assert 0 < self.min_prob <= self.max_prob <= 1
+ self.prob = self.max_prob
+ self.name = None # will be set in training loop
+
+ def forward(self, x: Tensor) -> Tensor:
+ """
+ In the forward pass, this function just returns the input unmodified.
+ In the backward pass, it will modify the gradients to ensure that the
+ distribution in each group has close to (lambda times I) as the covariance
+ after mean subtraction, with the same lambda across groups.
+ For whitening_limit > 1, there will be more freedom to violate this
+ constraint.
+
+ Args:
+ x: the input of shape (*, num_channels)
+
+ Returns:
+ x, unmodified. You should make sure
+ you use the returned value, or the graph will be freed
+ and nothing will happen in backprop.
+ """
+ grad_scale = float(self.grad_scale)
+ if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
+ return _no_op(x)
+ else:
+ return WhiteningPenaltyFunction.apply(x, self)
+
+
+class WithLoss(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, y: Tensor, name: str):
+ ctx.y_shape = y.shape
+ if random.random() < 0.002 and name is not None:
+ loss_sum = y.sum().item()
+ logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
+ return x
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ return (
+ ans_grad,
+ torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
+ None,
+ )
+
+
+def with_loss(x, y, name):
+ # returns x but adds y.sum() to the loss function.
+ return WithLoss.apply(x, y, name)
+
+
+class ScaleGradFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, alpha: float) -> Tensor:
+ ctx.alpha = alpha
+ return x
+
+ @staticmethod
+ def backward(ctx, grad: Tensor):
+ return grad * ctx.alpha, None
+
+
+def scale_grad(x: Tensor, alpha: float):
+ return ScaleGradFunction.apply(x, alpha)
+
+
+class ScaleGrad(nn.Module):
+ def __init__(self, alpha: float):
+ super().__init__()
+ self.alpha = alpha
+
+ def forward(self, x: Tensor) -> Tensor:
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+ return x
+ return scale_grad(x, self.alpha)
+
+
+class LimitParamValue(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, min: float, max: float):
+ ctx.save_for_backward(x)
+ assert max >= min
+ ctx.min = min
+ ctx.max = max
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor):
+ (x,) = ctx.saved_tensors
+ # where x < ctx.min, ensure all grads are negative (this will tend to make
+ # x more positive).
+ x_grad = x_grad * torch.where(
+ torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
+ )
+ # where x > ctx.max, ensure all grads are positive (this will tend to make
+ # x more negative).
+ x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
+ return x_grad, None, None
+
+
+def limit_param_value(
+ x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
+):
+ # You apply this to (typically) an nn.Parameter during training to ensure that its
+ # (elements mostly) stays within a supplied range. This is done by modifying the
+ # gradients in backprop.
+ # It's not necessary to do this on every batch: do it only some of the time,
+ # to save a little time.
+ if training and random.random() < prob:
+ return LimitParamValue.apply(x, min, max)
+ else:
+ return x
+
+
+def _no_op(x: Tensor) -> Tensor:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x
+ else:
+ # a no-op function that will have a node in the autograd graph,
+ # to avoid certain bugs relating to backward hooks
+ return x.chunk(1, dim=-1)[0]
+
+
+class Identity(torch.nn.Module):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, x):
+ return _no_op(x)
+
+
+class DoubleSwishFunction(torch.autograd.Function):
+ """
+ double_swish(x) = x * torch.sigmoid(x-1)
+
+ This is a definition, originally motivated by its close numerical
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
+
+ Memory-efficient derivative computation:
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
+ Now, s'(x) = s(x) * (1-s(x)).
+ double_swish'(x) = x * s'(x) + s(x).
+ = x * s(x) * (1-s(x)) + s(x).
+ = double_swish(x) * (1-s(x)) + s(x)
+ ... so we just need to remember s(x) but not x itself.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ s = torch.sigmoid(x - 1.0)
+ y = x * s
+
+ if requires_grad:
+ deriv = y * (1 - s) + s
+
+ # notes on derivative of x * sigmoid(x - 1):
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
+ # min \simeq -0.043638. Take floor as -0.044 so it's a lower bound
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
+ # floors), should be expectation-preserving.
+ floor = -0.044
+ ceil = 1.2
+ d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ deriv
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+ floor = -0.043637
+ ceil = 1.2
+
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class DoubleSwish(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
+ that we approximate closely with x * sigmoid(x-1).
+ """
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x * torch.sigmoid(x - 1.0)
+ return DoubleSwishFunction.apply(x)
+
+
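+# A small illustrative sketch (not invoked anywhere in this recipe): the custom autograd
+# function matches the plain formula x * sigmoid(x - 1) in the forward direction (the
+# backward pass uses a memory-saving quantized approximation of the derivative).
+def _example_double_swish():
+    x = torch.randn(100)
+    y = DoubleSwish()(x)
+    assert torch.allclose(y, x * torch.sigmoid(x - 1.0), atol=1.0e-6)
+
+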
+# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
+class Dropout2(nn.Module):
+ def __init__(self, p: FloatLike):
+ super().__init__()
+ self.p = p
+
+ def forward(self, x: Tensor) -> Tensor:
+ return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
+
+
+class MulForDropout3(torch.autograd.Function):
+ # returns (x * y * alpha) where alpha is a float and y doesn't require
+ # grad and is zero-or-one.
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, x, y, alpha):
+ assert not y.requires_grad
+ ans = x * y * alpha
+ ctx.save_for_backward(ans)
+ ctx.alpha = alpha
+ return ans
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, ans_grad):
+ (ans,) = ctx.saved_tensors
+ x_grad = ctx.alpha * ans_grad * (ans != 0)
+ return x_grad, None, None
+
+
+# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
+# and it lets you choose one dimension to share the dropout mask over
+class Dropout3(nn.Module):
+ def __init__(self, p: FloatLike, shared_dim: int):
+ super().__init__()
+ self.p = p
+ self.shared_dim = shared_dim
+
+ def forward(self, x: Tensor) -> Tensor:
+ p = float(self.p)
+ if not self.training or p == 0:
+ return _no_op(x)
+ scale = 1.0 / (1 - p)
+ rand_shape = list(x.shape)
+ rand_shape[self.shared_dim] = 1
+ mask = torch.rand(*rand_shape, device=x.device) > p
+ ans = MulForDropout3.apply(x, mask, scale)
+ return ans
+
+
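+# A small illustrative sketch (not invoked anywhere in this recipe): with shared_dim=1
+# the same dropout mask is used at every position along dim 1, and surviving entries
+# are rescaled by 1 / (1 - p).  The shape and p below are arbitrary.
+def _example_dropout3():
+    drop = Dropout3(p=0.5, shared_dim=1)
+    y = drop(torch.ones(2, 10, 8))
+    assert set(y.unique().tolist()) <= {0.0, 2.0}  # kept entries become 1 / (1 - 0.5)
+    assert torch.equal(y[:, 0], y[:, 1])  # mask shared across dim 1
+
+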
+class SwooshLFunction(torch.autograd.Function):
+ """
+ swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+
+ coeff = -0.08
+
+ with torch.cuda.amp.autocast(enabled=False):
+ with torch.enable_grad():
+ x = x.detach()
+ x.requires_grad = True
+ y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
+
+ if not requires_grad:
+ return y
+
+ y.backward(gradient=torch.ones_like(y))
+
+ grad = x.grad
+ floor = coeff
+ ceil = 1.0 + coeff + 0.005
+
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ grad
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+
+ coeff = -0.08
+ floor = coeff
+ ceil = 1.0 + coeff + 0.005
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class SwooshL(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-L activation."""
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+ if not x.requires_grad:
+ return k2.swoosh_l_forward(x)
+ else:
+ return k2.swoosh_l(x)
+ # return SwooshLFunction.apply(x)
+
+
+class SwooshLOnnx(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-L activation."""
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
+
+
+class SwooshRFunction(torch.autograd.Function):
+ """
+ swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
+
+ derivatives are between -0.08 and 0.92.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ with torch.enable_grad():
+ x = x.detach()
+ x.requires_grad = True
+ y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
+
+ if not requires_grad:
+ return y
+ y.backward(gradient=torch.ones_like(y))
+
+ grad = x.grad
+ floor = -0.08
+ ceil = 0.925
+
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ grad
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+ floor = -0.08
+ ceil = 0.925
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class SwooshR(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-R activation."""
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
+ if not x.requires_grad:
+ return k2.swoosh_r_forward(x)
+ else:
+ return k2.swoosh_r(x)
+ # return SwooshRFunction.apply(x)
+
+
+class SwooshROnnx(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-R activation."""
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
+
+
+# simple version of SwooshL that does not redefine the backprop, used in
+# ActivationDropoutAndLinearFunction.
+def SwooshLForward(x: Tensor):
+ x_offset = x - 4.0
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
+ return log_sum - 0.08 * x - 0.035
+
+
+# simple version of SwooshR that does not redefine the backprop, used in
+# ActivationDropoutAndLinearFunction.
+def SwooshRForward(x: Tensor):
+ x_offset = x - 1.0
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
+ return log_sum - 0.08 * x - 0.313261687
+
+
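+# A small illustrative sketch (not invoked anywhere in this recipe): the constant
+# 0.313261687 is log(1 + exp(-1)), chosen so that SwooshR(0) == 0; SwooshRForward
+# matches a torch.logaddexp-based evaluation of the same formula.
+def _example_swoosh_r_forward():
+    x = torch.tensor([0.0, 1.0, -5.0, 20.0])
+    y = SwooshRForward(x)
+    expected = torch.logaddexp(torch.zeros_like(x), x - 1.0) - 0.08 * x - 0.313261687
+    assert torch.allclose(y, expected, atol=1.0e-5)
+    assert abs(y[0].item()) < 1.0e-5
+
+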
+class ActivationDropoutAndLinearFunction(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor],
+ activation: str,
+ dropout_p: float,
+ dropout_shared_dim: Optional[int],
+ ):
+ if dropout_p != 0.0:
+ dropout_shape = list(x.shape)
+ if dropout_shared_dim is not None:
+ dropout_shape[dropout_shared_dim] = 1
+ # else it won't be very memory efficient.
+ dropout_mask = (1.0 / (1.0 - dropout_p)) * (
+ torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
+ )
+ else:
+ dropout_mask = None
+
+ ctx.save_for_backward(x, weight, bias, dropout_mask)
+
+ ctx.activation = activation
+
+ forward_activation_dict = {
+ "SwooshL": k2.swoosh_l_forward,
+ "SwooshR": k2.swoosh_r_forward,
+ }
+ # this will raise a KeyError if the activation is unrecognized; that is a
+ # genuine error, and we let it propagate to the user.
+ activation_func = forward_activation_dict[activation]
+ x = activation_func(x)
+ if dropout_mask is not None:
+ x = x * dropout_mask
+ x = torch.nn.functional.linear(x, weight, bias)
+ return x
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, ans_grad: Tensor):
+ saved = ctx.saved_tensors
+ (x, weight, bias, dropout_mask) = saved
+
+ forward_and_deriv_activation_dict = {
+ "SwooshL": k2.swoosh_l_forward_and_deriv,
+ "SwooshR": k2.swoosh_r_forward_and_deriv,
+ }
+ # the following line raises a KeyError if the activation is unrecognized;
+ # this is a genuine error, and we let it propagate to the user.
+ func = forward_and_deriv_activation_dict[ctx.activation]
+
+ y, func_deriv = func(x)
+ if dropout_mask is not None:
+ y = y * dropout_mask
+ # now compute derivative of y w.r.t. weight and bias..
+ # y: (..., in_channels), ans_grad: (..., out_channels),
+ (out_channels, in_channels) = weight.shape
+
+ in_channels = y.shape[-1]
+ g = ans_grad.reshape(-1, out_channels)
+ weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
+ y_deriv = torch.matmul(ans_grad, weight)
+ bias_deriv = None if bias is None else g.sum(dim=0)
+ x_deriv = y_deriv * func_deriv
+ if dropout_mask is not None:
+ # order versus func_deriv does not matter
+ x_deriv = x_deriv * dropout_mask
+
+ return x_deriv, weight_deriv, bias_deriv, None, None, None
+
+
+class ActivationDropoutAndLinear(torch.nn.Module):
+ """
+ This merges an activation function followed by dropout and then a nn.Linear module;
+ it does so in a memory efficient way so that it only stores the input to the whole
+ module. If activation == SwooshL and dropout_shared_dim != None, this will be
+ equivalent to:
+ nn.Sequential(SwooshL(),
+ Dropout3(dropout_p, shared_dim=dropout_shared_dim),
+ ScaledLinear(in_channels, out_channels, bias=bias,
+ initial_scale=initial_scale))
+ If dropout_shared_dim is None, the dropout would be equivalent to
+ Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout
+ mask is smaller.
+
+ Args:
+ in_channels: number of input channels, e.g. 256
+ out_channels: number of output channels, e.g. 256
+ bias: if true, have a bias
+ activation: the activation function, for now just support SwooshL.
+ dropout_p: the dropout probability or schedule (happens after nonlinearity).
+ dropout_shared_dim: the dimension, if any, across which the dropout mask is
+ shared (e.g. the time dimension). If None, this may be less memory
+ efficient if there are modules before this one that cache the input
+ for their backprop (e.g. Balancer or Whiten).
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bias: bool = True,
+ activation: str = "SwooshL",
+ dropout_p: FloatLike = 0.0,
+ dropout_shared_dim: Optional[int] = -1,
+ initial_scale: float = 1.0,
+ ):
+ super().__init__()
+ # create a temporary module of nn.Linear that we'll steal the
+ # weights and bias from
+ l = ScaledLinear(
+ in_channels, out_channels, bias=bias, initial_scale=initial_scale
+ )
+
+ self.weight = l.weight
+ # register_parameter handles the case where l.bias is None (i.e.
+ # bias=False) correctly; we use it rather than plain attribute
+ # assignment, possibly for reasons related to exporting the module.
+ self.register_parameter("bias", l.bias)
+
+ self.activation = activation
+ self.dropout_p = dropout_p
+ self.dropout_shared_dim = dropout_shared_dim
+
+ def forward(self, x: Tensor):
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ if self.activation == "SwooshL":
+ x = SwooshLForward(x)
+ elif self.activation == "SwooshR":
+ x = SwooshRForward(x)
+ else:
+ assert False, self.activation
+ return torch.nn.functional.linear(x, self.weight, self.bias)
+
+ return ActivationDropoutAndLinearFunction.apply(
+ x,
+ self.weight,
+ self.bias,
+ self.activation,
+ float(self.dropout_p),
+ self.dropout_shared_dim,
+ )
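+
+# Illustrative usage (a sketch; per the docstring this is equivalent to
+# SwooshL() -> Dropout3 -> ScaledLinear, up to small numerical differences
+# from k2's quantized Swoosh kernels, while storing only the module input
+# for backprop):
+# m = ActivationDropoutAndLinear(256, 256, activation="SwooshL",
+# dropout_p=0.1, dropout_shared_dim=-1)
+# y = m(torch.randn(10, 4, 256)) # (seq_len, batch, 256) -> (seq_len, batch, 256)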
+
+
+class ActivationDropoutAndLinear_lora(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bias: bool = True,
+ activation: str = "SwooshL",
+ dropout_p: FloatLike = 0.0,
+ dropout_shared_dim: Optional[int] = -1,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.0,
+ initial_scale: float = 1.0,
+ ):
+ super().__init__()
+ self.l = ScaledLinear_lora(
+ in_features=in_channels,
+ out_features=out_channels,
+ r=r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ initial_scale=initial_scale,
+ bias=bias,
+ )
+ self.weight = self.l.weight
+ self.register_parameter("bias", self.l.bias)
+
+ if activation == "SwooshL":
+ self.activation = SwooshL()
+ elif activation == "SwooshR":
+ self.activation = SwooshR()
+ else:
+ assert False, activation
+ self.dropout = Dropout3(dropout_p, dropout_shared_dim)
+
+ def forward(self, x: Tensor):
+ return self.l(self.dropout(self.activation(x)))
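+
+# Note: unlike ActivationDropoutAndLinear above, this LoRA variant does not
+# use the fused memory-efficient autograd function; it simply composes
+# activation -> dropout -> ScaledLinear_lora, presumably so that the
+# low-rank adapter path stays inside ordinary autograd.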
+
+
+def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
+ if num_channels <= x.shape[-1]:
+ return x[..., :num_channels]
+ else:
+ shape = list(x.shape)
+ shape[-1] = num_channels - shape[-1]
+ zeros = torch.zeros(shape, dtype=x.dtype, device=x.device)
+ return torch.cat((x, zeros), dim=-1)
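+
+# e.g. for x of shape (T, B, 256): convert_num_channels(x, 192) keeps only
+# the first 192 channels, while convert_num_channels(x, 384) zero-pads the
+# last dimension from 256 up to 384 channels.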
+
+
+def _test_whiten():
+ for proportion in [0.1, 0.5, 10.0]:
+ logging.info(f"_test_whiten(): proportion = {proportion}")
+ x = torch.randn(100, 128)
+ direction = torch.randn(128)
+ coeffs = torch.randn(100, 1)
+ x += proportion * direction * coeffs
+
+ x.requires_grad = True
+
+ m = Whiten(
+ 1, # num_groups
+ 5.0, # whitening_limit
+ prob=1.0,
+ grad_scale=0.1, # grad_scale
+ )
+
+ for _ in range(4):
+ y = m(x)
+
+ y_grad = torch.randn_like(x)
+ y.backward(gradient=y_grad)
+
+ if proportion < 0.2:
+ assert torch.allclose(x.grad, y_grad)
+ elif proportion > 1.0:
+ assert not torch.allclose(x.grad, y_grad)
+
+
+def _test_balancer_sign():
+ probs = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
+ x = x.detach()
+ x.requires_grad = True
+ m = Balancer(
+ probs.numel(),
+ channel_dim=0,
+ min_positive=0.05,
+ max_positive=0.95,
+ min_abs=0.0,
+ prob=1.0,
+ )
+
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_balancer_sign: x = ", x)
+ print("_test_balancer_sign: y grad = ", y_grad)
+ print("_test_balancer_sign: x grad = ", x.grad)
+
+
+def _test_balancer_magnitude():
+ magnitudes = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
+ x = x.detach()
+ x.requires_grad = True
+ m = Balancer(
+ magnitudes.numel(),
+ channel_dim=0,
+ min_positive=0.0,
+ max_positive=1.0,
+ min_abs=0.2,
+ max_abs=0.7,
+ prob=1.0,
+ )
+
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_balancer_magnitude: x = ", x)
+ print("_test_balancer_magnitude: y grad = ", y_grad)
+ print("_test_balancer_magnitude: x grad = ", x.grad)
+
+
+def _test_double_swish_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = DoubleSwish()
+
+ tol = (1.2 - (-0.043637)) / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_swooshl_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = SwooshL()
+
+ tol = 1.0 / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_swooshr_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = SwooshR()
+
+ tol = 1.0 / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_softmax():
+ a = torch.randn(2, 10, dtype=torch.float64)
+ b = a.clone()
+ a.requires_grad = True
+ b.requires_grad = True
+ a.softmax(dim=1)[:, 0].sum().backward()
+ print("a grad = ", a.grad)
+ softmax(b, dim=1)[:, 0].sum().backward()
+ print("b grad = ", b.grad)
+ assert torch.allclose(a.grad, b.grad)
+
+
+def _test_piecewise_linear():
+ p = PiecewiseLinear((0, 10.0))
+ for x in [-100, 0, 100]:
+ assert p(x) == 10.0
+ p = PiecewiseLinear((0, 10.0), (1, 0.0))
+ for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
+ print("x, y = ", x, y)
+ assert p(x) == y, (x, p(x), y)
+
+ q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
+ x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
+ pq = p.max(q)
+ for x in x_vals:
+ y1 = max(p(x), q(x))
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+ pq = p.min(q)
+ for x in x_vals:
+ y1 = min(p(x), q(x))
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+ pq = p + q
+ for x in x_vals:
+ y1 = p(x) + q(x)
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+
+
+def _test_activation_dropout_and_linear():
+ in_channels = 20
+ out_channels = 30
+
+ for bias in [True, False]:
+ # Actually we don't test dropout_p != 0.0, because the two forward passes
+ # would give different answers. This is because we are using the k2
+ # implementation of swoosh_l and swoosh_r inside SwooshL() and SwooshR(),
+ # and they call randn() internally, perturbing the random state.
+ for dropout_p in [0.0]:
+ for activation in ["SwooshL", "SwooshR"]:
+ m1 = nn.Sequential(
+ SwooshL() if activation == "SwooshL" else SwooshR(),
+ Dropout3(p=dropout_p, shared_dim=-1),
+ ScaledLinear(
+ in_channels, out_channels, bias=bias, initial_scale=0.5
+ ),
+ )
+ m2 = ActivationDropoutAndLinear(
+ in_channels,
+ out_channels,
+ bias=bias,
+ initial_scale=0.5,
+ activation=activation,
+ dropout_p=dropout_p,
+ )
+ with torch.no_grad():
+ m2.weight[:] = m1[2].weight
+ if bias:
+ m2.bias[:] = m1[2].bias
+ # make sure forward gives same result.
+ x1 = torch.randn(10, in_channels)
+ x1.requires_grad = True
+
+ # TEMP.
+ assert torch.allclose(
+ SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
+ )
+
+ x2 = x1.clone().detach()
+ x2.requires_grad = True
+ seed = 10
+ torch.manual_seed(seed)
+ y1 = m1(x1)
+ y_grad = torch.randn_like(y1)
+ y1.backward(gradient=y_grad)
+ torch.manual_seed(seed)
+ y2 = m2(x2)
+ y2.backward(gradient=y_grad)
+
+ print(
+ f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
+ )
+ print("y1 = ", y1)
+ print("y2 = ", y2)
+ assert torch.allclose(y1, y2, atol=0.02)
+ assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
+ if bias:
+ assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
+ print("x1.grad = ", x1.grad)
+ print("x2.grad = ", x2.grad)
+
+ def isclose(a, b):
+ # return true if cosine similarity is > 0.9.
+ return (a * b).sum() > 0.9 * (
+ (a**2).sum() * (b**2).sum()
+ ).sqrt()
+
+ # the SwooshL() implementation has a noisy gradient due to 1-byte
+ # storage of it.
+ assert isclose(x1.grad, x2.grad)
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.INFO)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ _test_piecewise_linear()
+ _test_softmax()
+ _test_whiten()
+ _test_balancer_sign()
+ _test_balancer_magnitude()
+ _test_double_swish_deriv()
+ _test_swooshr_deriv()
+ _test_swooshl_deriv()
+ _test_activation_dropout_and_linear()
diff --git a/egs/librispeech/ASR/zipformer_lora/scaling_converter.py b/egs/librispeech/ASR/zipformer_lora/scaling_converter.py
new file mode 120000
index 000000000..bc7c7b5e3
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/scaling_converter.py
@@ -0,0 +1 @@
+../zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/subsampling.py b/egs/librispeech/ASR/zipformer_lora/subsampling.py
new file mode 120000
index 000000000..d178adc2e
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/subsampling.py
@@ -0,0 +1 @@
+../zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py
new file mode 100755
index 000000000..3ccf7d2f1
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/train.py
@@ -0,0 +1,1398 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# For non-streaming model training:
+./zipformer_lora/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer_lora/exp \
+ --full-libri 1 \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer_lora/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer_lora/exp \
+ --causal 1 \
+ --full-libri 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
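+ # e.g. with batch_idx_train=1000, max_duration=1000 (seconds of audio per
+ # GPU per batch), world_size=4 and ref_duration=600, this returns
+ # 1000 * 4000 / 600 ~= 6667 "reference" batches.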
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
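+ # e.g. _to_int_tuple("2,2,3,4,3,2") == (2, 2, 3, 4, 3, 2); this is how the
+ # comma-separated command-line arguments above are turned into per-stack
+ # tuples.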
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # ramp the scale on the simple loss down from 1.0 at the start of
+ # training to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
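+ # e.g. with the defaults in this recipe (warm_step=2000,
+ # simple_loss_scale=0.5): at batch 0 the (simple, pruned) scales are
+ # (1.0, 0.1), at batch 1000 they are (0.75, 0.55), and from batch 2000
+ # onwards they stay at (0.5, 1.0).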
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ if params.full_libri:
+ train_cuts = librispeech.train_all_shuf_cuts()
+
+ # previously we used the following code to load all training cuts,
+ # strictly speaking, shuffled training cuts should be used instead,
+ # but we leave the code here to demonstrate that there is an option
+ # like this to combine multiple cutsets
+
+ # train_cuts = librispeech.train_clean_100_cuts()
+ # train_cuts += librispeech.train_clean_360_cuts()
+ # train_cuts += librispeech.train_other_500_cuts()
+ else:
+ train_cuts = librispeech.train_clean_100_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
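+ # e.g. c.num_frames == 1000 gives T == ((1000 - 7) // 2 + 1) // 2 == 248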
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+ valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py
new file mode 100644
index 000000000..43865609a
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py
@@ -0,0 +1,2522 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import math
+import random
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+from encoder_interface import EncoderInterface
+from scaling import (
+ Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+ ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+ ActivationDropoutAndLinear,
+ ActivationDropoutAndLinear_lora,
+ Balancer,
+ BiasNorm,
+ ChunkCausalDepthwiseConv1d,
+ Dropout2,
+ FloatLike,
+ ScaledLinear_lora,
+ ScheduledFloat,
+ Whiten,
+ convert_num_channels,
+ limit_param_value,
+ penalize_abs_values_gt,
+ softmax,
+)
+from torch import Tensor, nn
+
+
+class Zipformer2(EncoderInterface):
+ """
+ Args:
+
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
+ as downsampling_factor if they are single ints or one-element tuples. The length of
+ downsampling_factor defines the number of stacks.
+
+ output_downsampling_factor (int): how much to downsample at the output. Note:
+ we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
+ You should probably leave this at 2.
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
+ Note: this is in addition to the downsampling factor of 2 that is applied in
+ the frontend (self.encoder_embed).
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
+ encoder stack.
+ num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
+ encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
+ the encoder stacks for purposes of per-frame dropout (recommend 256 for
+ now).
+ query_head_dim (int or Tuple[int]): dimension of query and key per attention
+ head; per stack, if a tuple.
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
+ attention head
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
+ Must be at least 4.
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
+ cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
+
+ pos_dim (int): the dimension of each positional-encoding vector prior to projection,
+ e.g. 128.
+
+ dropout (float): dropout rate
+ warmup_batches (float): number of batches to warm up over; this controls
+ dropout of encoder layers.
+ causal (bool): if True, support chunkwise causal convolution. This should
+ not hurt WER as no modeling power is lost, but the convolution modules will be
+ slightly slower and use more memory. Enables use of the chunk_size and
+ left_context_chunks options in forward(), which simulates streaming
+ decoding.
+ chunk_size: (list of int): only set this to other than [-1] if causal;
+ the chunk size will be randomly chosen from this list. -1 means no chunking.
+ left_context_frames: (list of int): determines the number of left-
+ context chunks for causal training; will be rounded to a number of
+ chunks. Must not be less than cnn_module_kernel (after factoring in
+ rounding and downsampling); an error will be thrown if this is violated.
+ """
+
+ def __init__(
+ self,
+ output_downsampling_factor: int = 2,
+ downsampling_factor: Tuple[int] = (2, 4),
+ encoder_dim: Union[int, Tuple[int]] = 384,
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
+ encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
+ query_head_dim: Union[int, Tuple[int]] = 24,
+ pos_head_dim: Union[int, Tuple[int]] = 4,
+ value_head_dim: Union[int, Tuple[int]] = 12,
+ num_heads: Union[int, Tuple[int]] = 8,
+ feedforward_dim: Union[int, Tuple[int]] = 1536,
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
+ pos_dim: int = 192,
+ dropout: FloatLike = None, # see code below for default
+ warmup_batches: float = 4000.0,
+ causal: bool = False,
+ chunk_size: Tuple[int] = [-1],
+ left_context_frames: Tuple[int] = [-1],
+ use_lora: bool = True,
+ lora_r: int = 0,
+ ) -> None:
+ super(Zipformer2, self).__init__()
+
+ if dropout is None:
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+
+ def _to_tuple(x):
+ """Converts a single int or a 1-tuple of an int to a tuple with the same length
+ as downsampling_factor"""
+ if isinstance(x, int):
+ x = (x,)
+ if len(x) == 1:
+ x = x * len(downsampling_factor)
+ else:
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
+ return x
+
+ self.output_downsampling_factor = output_downsampling_factor # int
+ self.downsampling_factor = downsampling_factor # tuple
+ self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
+ self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
+ encoder_unmasked_dim
+ ) # tuple
+ num_encoder_layers = _to_tuple(num_encoder_layers)
+ self.num_encoder_layers = num_encoder_layers
+ self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
+ self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
+ pos_head_dim = _to_tuple(pos_head_dim)
+ self.num_heads = num_heads = _to_tuple(num_heads)
+ feedforward_dim = _to_tuple(feedforward_dim)
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
+
+ self.causal = causal
+ self.chunk_size = chunk_size
+ self.left_context_frames = left_context_frames
+
+ self.lora_r = lora_r if use_lora else 0
+
+ for u, d in zip(encoder_unmasked_dim, encoder_dim):
+ assert u <= d
+
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
+ encoders = []
+
+ num_encoders = len(downsampling_factor)
+ for i in range(num_encoders):
+ encoder_layer = Zipformer2EncoderLayer(
+ embed_dim=encoder_dim[i],
+ pos_dim=pos_dim,
+ num_heads=num_heads[i],
+ query_head_dim=query_head_dim[i],
+ pos_head_dim=pos_head_dim[i],
+ value_head_dim=value_head_dim[i],
+ feedforward_dim=feedforward_dim[i],
+ dropout=dropout,
+ cnn_module_kernel=cnn_module_kernel[i],
+ causal=causal,
+ lora_r=self.lora_r,
+ )
+
+ # For the segment of the warmup period, we let the Conv2dSubsampling
+ # layer learn something. Then we start to warm up the other encoders.
+ encoder = Zipformer2Encoder(
+ encoder_layer,
+ num_encoder_layers[i],
+ pos_dim=pos_dim,
+ dropout=dropout,
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+ )
+
+ if downsampling_factor[i] != 1:
+ encoder = DownsampledZipformer2Encoder(
+ encoder,
+ dim=encoder_dim[i],
+ downsample=downsampling_factor[i],
+ dropout=dropout,
+ )
+
+ encoders.append(encoder)
+
+ self.encoders = nn.ModuleList(encoders)
+
+ self.downsample_output = SimpleDownsample(
+ max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
+ )
+
+ def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
+ """
+ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
+ randomized feature masks, one per encoder.
+ On e.g. 15% of frames, these masks will zero out all encoder dims larger than
+ some supplied number, e.g. >256, so in effect on those frames we are using
+ a smaller encoder dim.
+
+ We generate the random masks at this level because we want the 2 masks to 'agree'
+ all the way up the encoder stack. This will mean that the 1st mask will have
+ mask values repeated self.zipformer_subsampling_factor times.
+
+ Args:
+ x: the embeddings (needed for the shape and dtype and device), of shape
+ (1, batch_size, encoder_dims0)
+ """
+ num_encoders = len(self.encoder_dim)
+ if not self.training:
+ return [1.0] * num_encoders
+
+ (num_frames0, batch_size, _encoder_dims0) = x.shape
+
+ assert self.encoder_dim[0] == _encoder_dims0, (
+ self.encoder_dim[0],
+ _encoder_dims0,
+ )
+
+ feature_mask_dropout_prob = 0.125
+
+ # mask1 shape: (1, batch_size, 1)
+ mask1 = (
+ torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
+ ).to(x.dtype)
+
+ # mask2 has additional sequences masked, about twice the number.
+ mask2 = torch.logical_and(
+ mask1,
+ (
+ torch.rand(1, batch_size, 1, device=x.device)
+ > feature_mask_dropout_prob
+ ).to(x.dtype),
+ )
+
+ # dim: (1, batch_size, 2)
+ mask = torch.cat((mask1, mask2), dim=-1)
+
+ feature_masks = []
+ for i in range(num_encoders):
+ channels = self.encoder_dim[i]
+ feature_mask = torch.ones(
+ 1, batch_size, channels, dtype=x.dtype, device=x.device
+ )
+ u1 = self.encoder_unmasked_dim[i]
+ u2 = u1 + (channels - u1) // 2
+
+ feature_mask[:, :, u1:u2] *= mask[..., 0:1]
+ feature_mask[:, :, u2:] *= mask[..., 1:2]
+
+ feature_masks.append(feature_mask)
+
+ return feature_masks
+
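+ # A minimal worked example of the masking above (values are illustrative, not defaults):
+ # for a stack with encoder_dim=384 and encoder_unmasked_dim=256 we get u1=256 and
+ # u2 = 256 + (384 - 256) // 2 = 320, so channels [0, 256) are never masked,
+ # channels [256, 320) are zeroed whenever mask1 is 0 for that sequence, and
+ # channels [320, 384) are zeroed whenever mask2 is 0 (mask2 is 0 roughly twice as often).
+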
+ def get_chunk_info(self) -> Tuple[int, int]:
+ """
+ Returns chunk_size and left_context_chunks.
+ """
+ if not self.causal:
+ return -1, -1
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ assert len(self.chunk_size) == 1, self.chunk_size
+ chunk_size = self.chunk_size[0]
+ else:
+ chunk_size = random.choice(self.chunk_size)
+
+ if chunk_size == -1:
+ left_context_chunks = -1
+ else:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ assert len(self.left_context_frames) == 1, self.left_context_frames
+ left_context_frames = self.left_context_frames[0]
+ else:
+ left_context_frames = random.choice(self.left_context_frames)
+ # Note: in Python, -1 // n == -1 for n > 0
+ left_context_chunks = left_context_frames // chunk_size
+ if left_context_chunks == 0:
+ left_context_chunks = 1
+
+ return chunk_size, left_context_chunks
+
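+ # Worked example (illustrative values, not defaults): with chunk_size=32 and
+ # left_context_frames=128 we get left_context_chunks = 128 // 32 = 4; with
+ # left_context_frames=-1, floor division gives -1, which _get_attn_mask() treats
+ # as unlimited left context.
+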
+ def forward(
+ self,
+ x: Tensor,
+ x_lens: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (seq_len, batch_size, feature_dim).
+ x_lens:
+ A tensor of shape (batch_size,) containing the number of frames in
+ `x` before padding.
+ src_key_padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ Returns:
+ Return a tuple containing 2 tensors:
+ - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
+ - lengths, a tensor of shape (batch_size,) containing the number
+ of frames in `embeddings` before padding.
+ """
+ outputs = []
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ feature_masks = [1.0] * len(self.encoder_dim)
+ else:
+ feature_masks = self.get_feature_masks(x)
+
+ chunk_size, left_context_chunks = self.get_chunk_info()
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ # Simulated streaming decoding is not supported when exporting the model
+ attn_mask = None
+ else:
+ attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
+
+ for i, module in enumerate(self.encoders):
+ ds = self.downsampling_factor[i]
+ x = convert_num_channels(x, self.encoder_dim[i])
+
+ x = module(
+ x,
+ chunk_size=chunk_size,
+ feature_mask=feature_masks[i],
+ src_key_padding_mask=(
+ None
+ if src_key_padding_mask is None
+ else src_key_padding_mask[..., ::ds]
+ ),
+ attn_mask=attn_mask,
+ )
+ outputs.append(x)
+
+ # if the last output has the largest dimension, x will be unchanged,
+ # it will be the same as outputs[-1]. Otherwise it will be concatenated
+ # from different pieces of 'outputs', taking each dimension from the
+ # most recent output that has it present.
+ x = self._get_full_dim_output(outputs)
+ x = self.downsample_output(x)
+ # class SimpleDownsample has this rounding behavior.
+ assert self.output_downsampling_factor == 2, self.output_downsampling_factor
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ lengths = (x_lens + 1) // 2
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ lengths = (x_lens + 1) // 2
+
+ return x, lengths
+
+ def _get_attn_mask(
+ self, x: Tensor, chunk_size: int, left_context_chunks: int
+ ) -> Optional[Tensor]:
+ """
+ Return None if chunk_size == -1, else return attention mask of shape
+ (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True
+ means a masked position.
+ Args:
+ x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
+ chunk_size: chunk size; must be a multiple of all downsampling factors.
+ """
+ if chunk_size <= 0:
+ return None
+ assert all(chunk_size % d == 0 for d in self.downsampling_factor)
+ if left_context_chunks >= 0:
+ num_encoders = len(self.encoder_dim)
+ assert all(
+ chunk_size * left_context_chunks
+ >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
+ for i in range(num_encoders)
+ )
+ else:
+ left_context_chunks = 1000000
+
+ seq_len = x.shape[0]
+
+ # t is frame index, shape (seq_len,)
+ t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
+ # c is chunk index for each frame, shape (seq_len,)
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ c = t // chunk_size
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ c = t // chunk_size
+ src_c = c
+ tgt_c = c.unsqueeze(-1)
+
+ attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
+ if __name__ == "__main__":
+ logging.info(f"attn_mask = {attn_mask}")
+ return attn_mask
+
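+ # A hand-worked sketch of the mask above (illustrative values): with seq_len=6,
+ # chunk_size=2 and left_context_chunks=1, the chunk index per frame is
+ # [0, 0, 1, 1, 2, 2] and attn_mask (True = masked; rows are target frames,
+ # columns are source frames) is
+ #   [[F, F, T, T, T, T],
+ #    [F, F, T, T, T, T],
+ #    [F, F, F, F, T, T],
+ #    [F, F, F, F, T, T],
+ #    [T, T, F, F, F, F],
+ #    [T, T, F, F, F, F]]
+ # i.e. each frame may attend to its own chunk plus one chunk of left context.
+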
+ def _get_full_dim_output(self, outputs: List[Tensor]):
+ num_encoders = len(self.encoder_dim)
+ assert len(outputs) == num_encoders
+ output_dim = max(self.encoder_dim)
+ output_pieces = [outputs[-1]]
+ cur_dim = self.encoder_dim[-1]
+ for i in range(num_encoders - 2, -1, -1):
+ d = self.encoder_dim[i]
+ if d > cur_dim:
+ this_output = outputs[i]
+ output_pieces.append(this_output[..., cur_dim:d])
+ cur_dim = d
+ assert cur_dim == output_dim
+ return torch.cat(output_pieces, dim=-1)
+
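+ # Worked example with a typical (but here purely illustrative) encoder_dim tuple
+ # (192, 256, 384, 512, 384, 256): the full 512-dim output is assembled as
+ #   channels [0, 256)   from outputs[5] (the last stack, 256 dims),
+ #   channels [256, 384) from outputs[4] (384 dims),
+ #   channels [384, 512) from outputs[3] (512 dims),
+ # i.e. each channel range is taken from the most recent stack that has it.
+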
+ def streaming_forward(
+ self,
+ x: Tensor,
+ x_lens: Tensor,
+ states: List[Tensor],
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (seq_len, batch_size, feature_dim).
+ x_lens:
+ A tensor of shape (batch_size,) containing the number of frames in
+ `x` before padding.
+ states: list of cached tensors of all encoder layers. For layer-i,
+ states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+ cached_conv1, cached_conv2).
+ src_key_padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ Returns:
+ Return a tuple containing 3 elements:
+ - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
+ - lengths, a tensor of shape (batch_size,) containing the number
+ of frames in `embeddings` before padding.
+ - updated states
+ """
+ outputs = []
+ new_states = []
+ layer_offset = 0
+
+ for i, module in enumerate(self.encoders):
+ num_layers = module.num_layers
+ ds = self.downsampling_factor[i]
+ x = convert_num_channels(x, self.encoder_dim[i])
+
+ x, new_layer_states = module.streaming_forward(
+ x,
+ states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
+ left_context_len=self.left_context_frames[0] // ds,
+ src_key_padding_mask=src_key_padding_mask[..., ::ds],
+ )
+ layer_offset += num_layers
+ outputs.append(x)
+ new_states += new_layer_states
+
+ # if the last output has the largest dimension, x will be unchanged,
+ # it will be the same as outputs[-1]. Otherwise it will be concatenated
+ # from different pieces of 'outputs', taking each dimension from the
+ # most recent output that has it present.
+ x = self._get_full_dim_output(outputs)
+ x = self.downsample_output(x)
+ # class SimpleDownsample has this rounding behavior.
+ assert self.output_downsampling_factor == 2
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ lengths = (x_lens + 1) // 2
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ lengths = (x_lens + 1) // 2
+
+ return x, lengths, new_states
+
+ @torch.jit.export
+ def get_init_states(
+ self,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+ ) -> List[Tensor]:
+ """Get initial states.
+
+ A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ """
+ states = []
+ for i, module in enumerate(self.encoders):
+ num_layers = module.num_layers
+ embed_dim = self.encoder_dim[i]
+ ds = self.downsampling_factor[i]
+ num_heads = self.num_heads[i]
+ key_dim = self.query_head_dim[i] * num_heads
+ value_dim = self.value_head_dim[i] * num_heads
+ downsample_left = self.left_context_frames[0] // ds
+ nonlin_attn_head_dim = 3 * embed_dim // 4
+ conv_left_pad = self.cnn_module_kernel[i] // 2
+ for layer in range(num_layers):
+ cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
+ device
+ )
+ cached_nonlin_attn = torch.zeros(
+ 1, batch_size, downsample_left, nonlin_attn_head_dim
+ ).to(device)
+ cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
+ device
+ )
+ cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
+ device
+ )
+ cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
+ device
+ )
+ cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
+ device
+ )
+ states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+
+ return states
+
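+ # Note on the streaming cache layout: `states` (see streaming_forward() and
+ # get_init_states()) is a flat list with 6 cached tensors per layer, so with e.g.
+ # num_encoder_layers = (2, 2, 3, 4, 3, 2) (an illustrative configuration) it holds
+ # 6 * 16 = 96 tensors, and layer i's cache is states[i * 6 : (i + 1) * 6].
+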
+
+def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
+ return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
+
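+ # ScheduledFloat (from scaling.py) interpolates piecewise-linearly in the global batch
+ # count, so _whitening_schedule(4.0, ratio=3.0) is 4.0 at batch 0, roughly 8.0 around
+ # batch 10000, and 12.0 from batch 20000 onwards; `default` is the value used when no
+ # batch count is available (e.g. at inference time).
+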
+
+def _balancer_schedule(min_prob: float):
+ return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))
+
+
+class Zipformer2EncoderLayer(nn.Module):
+ """
+ Args:
+ embed_dim: the number of expected features in the input (required).
+ num_heads: the number of heads in the multi-head attention modules (required).
+ feedforward_dim: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ cnn_module_kernel (int): Kernel size of convolution module.
+
+ Examples::
+ >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, pos_dim=192, num_heads=8, query_head_dim=32, pos_head_dim=4, value_head_dim=12, feedforward_dim=2048)
+ >>> src = torch.rand(10, 32, 512)
+ >>> pos_emb = torch.rand(1, 19, 192)
+ >>> out = encoder_layer(src, pos_emb)
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ value_head_dim: int,
+ feedforward_dim: int,
+ dropout: FloatLike = 0.1,
+ cnn_module_kernel: int = 31,
+ causal: bool = False,
+ attention_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ conv_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ const_attention_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.25), (4000.0, 0.025), default=0
+ ),
+ ff2_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ ff3_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ bypass_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.5), (4000.0, 0.02), default=0
+ ),
+ lora_r: int = 0,
+ lora_alpha: int = 4,
+ lora_dropout: float = 0.0,
+ ) -> None:
+ super(Zipformer2EncoderLayer, self).__init__()
+ self.embed_dim = embed_dim
+
+ # self.bypass implements layer skipping as well as bypass; see its default values.
+ self.bypass = BypassModule(
+ embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
+ )
+ # bypass_mid is bypass used in the middle of the layer.
+ self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
+
+ # skip probability for dynamic modules (meaning: anything but feedforward).
+ self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
+ # an additional skip probability that applies to ConvModule to stop it from
+ # contributing too much early on.
+ self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+
+ # ff2_skip_rate is to prevent the ff2 module from having output that's too big
+ # compared to its residual.
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
+ self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
+
+ self.const_attention_rate = copy.deepcopy(const_attention_rate)
+
+ self.self_attn_weights = RelPositionMultiheadAttentionWeights(
+ embed_dim,
+ pos_dim=pos_dim,
+ num_heads=num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ dropout=0.0,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.self_attn1 = SelfAttention(
+ embed_dim,
+ num_heads,
+ value_head_dim,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.self_attn2 = SelfAttention(
+ embed_dim,
+ num_heads,
+ value_head_dim,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.feed_forward1 = FeedforwardModule(
+ embed_dim,
+ (feedforward_dim * 3) // 4,
+ dropout,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.feed_forward2 = FeedforwardModule(
+ embed_dim,
+ feedforward_dim,
+ dropout,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.feed_forward3 = FeedforwardModule(
+ embed_dim,
+ (feedforward_dim * 5) // 4,
+ dropout,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.nonlin_attention = NonlinAttention(
+ embed_dim, hidden_channels=3 * embed_dim // 4
+ )
+
+ self.conv_module1 = ConvolutionModule(
+ embed_dim, cnn_module_kernel, causal=causal
+ )
+
+ self.conv_module2 = ConvolutionModule(
+ embed_dim, cnn_module_kernel, causal=causal
+ )
+
+ # TODO: remove it
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+
+ self.norm = BiasNorm(embed_dim)
+
+ self.balancer1 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.2,
+ max_abs=4.0,
+ )
+
+ # balancer for output of NonlinAttentionModule
+ self.balancer_na = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+ prob=0.05, # out of concern for memory usage
+ )
+
+ # balancer for the output of feedforward2; prevents it from staying too
+ # small. We give this a very small probability, even at the start of
+ # training; it is meant to fix a rare problem, and it's OK to fix it slowly.
+ self.balancer_ff2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
+ max_abs=2.0,
+ prob=0.05,
+ )
+
+ self.balancer_ff3 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
+ max_abs=4.0,
+ prob=0.05,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(4.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.balancer2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.1,
+ max_abs=4.0,
+ )
+
+ def get_sequence_dropout_mask(
+ self, x: Tensor, dropout_rate: float
+ ) -> Optional[Tensor]:
+ if (
+ dropout_rate == 0.0
+ or not self.training
+ or torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ ):
+ return None
+ batch_size = x.shape[1]
+ mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
+ return mask
+
+ def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
+ """
+ Apply sequence-level dropout to x.
+ x shape: (seq_len, batch_size, embed_dim)
+ """
+ dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
+ if dropout_mask is None:
+ return x
+ else:
+ return x * dropout_mask
+
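+ # Note the mask has shape (batch_size, 1): a whole sequence either keeps the module's
+ # contribution or drops it entirely, e.g. with dropout_rate=0.05 roughly 5% of the
+ # sequences in a batch skip this sub-module on a given step; no 1/(1-p) rescaling is
+ # applied.
+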
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ chunk_size: int = -1,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Pass the input through the encoder layer.
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+ chunk_size: the number of frames per chunk (must be > 0 if chunking is used); -1 means no chunking.
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns:
+ A tensor which has the same shape as src
+ """
+ src_orig = src
+
+ # dropout rate for non-feedforward submodules
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ attention_skip_rate = 0.0
+ else:
+ attention_skip_rate = (
+ float(self.attention_skip_rate) if self.training else 0.0
+ )
+
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ attn_weights = self.self_attn_weights(
+ src,
+ pos_emb=pos_emb,
+ attn_mask=attn_mask,
+ key_padding_mask=src_key_padding_mask,
+ )
+
+ src = src + self.feed_forward1(src)
+
+ self_attn_dropout_mask = self.get_sequence_dropout_mask(
+ src, attention_skip_rate
+ )
+
+ selected_attn_weights = attn_weights[0:1]
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif not self.training and random.random() < float(self.const_attention_rate):
+ # Make attention weights constant. The intention is to
+ # encourage these modules to do something similar to an
+ # averaging-over-time operation.
+ # only need the mask, can just use the 1st one and expand later
+ selected_attn_weights = selected_attn_weights[0:1]
+ selected_attn_weights = (selected_attn_weights > 0.0).to(
+ selected_attn_weights.dtype
+ )
+ selected_attn_weights = selected_attn_weights * (
+ 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
+ )
+
+ na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+
+ src = src + (
+ na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
+ )
+
+ self_attn = self.self_attn1(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.conv_module1(
+ src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff2_skip_rate = 0.0
+ else:
+ ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
+ )
+
+ # bypass in the middle of the layer.
+ src = self.bypass_mid(src_orig, src)
+
+ self_attn = self.self_attn2(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.conv_module2(
+ src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff3_skip_rate = 0.0
+ else:
+ ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
+ )
+
+ src = self.balancer1(src)
+ src = self.norm(src)
+
+ src = self.bypass(src_orig, src)
+
+ src = self.balancer2(src)
+ src = self.whiten(src)
+
+ return src
+
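+ # Summary of the layer structure above: ff1 -> nonlin_attention -> self_attn1 -> conv1
+ # -> ff2 -> mid bypass -> self_attn2 -> conv2 -> ff3 -> balancer1/norm -> final bypass
+ # -> balancer2/whiten, with all attention-based sub-modules sharing the attention
+ # weights computed once per layer by self_attn_weights.
+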
+ def streaming_forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ cached_key: Tensor,
+ cached_nonlin_attn: Tensor,
+ cached_val1: Tensor,
+ cached_val2: Tensor,
+ cached_conv1: Tensor,
+ cached_conv2: Tensor,
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """Pass the input through the encoder layer in streaming forward mode.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
+ (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
+ cached_key: cached attention key tensor of left context,
+ of shape (left_context_len, batch_size, key_dim)
+ cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
+ (num_heads, batch_size, left_context_len, head_dim)
+ cached_val1: cached left context for the first attention module,
+ of shape (left_context_len, batch_size, value_dim)
+ cached_val2: cached left context for the second attention module,
+ of shape (left_context_len, batch_size, value_dim)
+ cached_conv1: cached left context for the first convolution module,
+ of shape (batch_size, channels, left_pad)
+ cached_conv2: cached left context for the second convolution module,
+ of shape (batch_size, channels, left_pad)
+ left_context_len: number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape
+ (batch_size, left_context_len + seq_len); True means masked position.
+ May be None.
+
+ Returns:
+ - x, with the same shape as src
+ - updated cached_key
+ - updated cached_nonlin_attn
+ - updated cached_val1
+ - updated cached_val2
+ - updated cached_conv1
+ - updated cached_conv2
+ """
+ src_orig = src
+
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ attn_weights, cached_key = self.self_attn_weights.streaming_forward(
+ src,
+ pos_emb=pos_emb,
+ cached_key=cached_key,
+ left_context_len=left_context_len,
+ key_padding_mask=src_key_padding_mask,
+ )
+
+ src = src + self.feed_forward1(src)
+
+ na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
+ src,
+ attn_weights[0:1],
+ cached_x=cached_nonlin_attn,
+ left_context_len=left_context_len,
+ )
+ src = src + na
+
+ self_attn, cached_val1 = self.self_attn1.streaming_forward(
+ src,
+ attn_weights=attn_weights,
+ cached_val=cached_val1,
+ left_context_len=left_context_len,
+ )
+ src = src + self_attn
+
+ src_conv, cached_conv1 = self.conv_module1.streaming_forward(
+ src,
+ cache=cached_conv1,
+ src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
+ )
+ src = src + src_conv
+
+ src = src + self.feed_forward2(src)
+
+ # bypass in the middle of the layer.
+ src = self.bypass_mid(src_orig, src)
+
+ self_attn, cached_val2 = self.self_attn2.streaming_forward(
+ src,
+ attn_weights=attn_weights,
+ cached_val=cached_val2,
+ left_context_len=left_context_len,
+ )
+ src = src + self_attn
+
+ src_conv, cached_conv2 = self.conv_module2.streaming_forward(
+ src,
+ cache=cached_conv2,
+ src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
+ )
+ src = src + src_conv
+
+ src = src + self.feed_forward3(src)
+
+ src = self.norm(src)
+
+ src = self.bypass(src_orig, src)
+
+ return (
+ src,
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ )
+
+
+class Zipformer2Encoder(nn.Module):
+ r"""Zipformer2Encoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+ pos_dim: the dimension for the relative positional encoding
+
+ Examples::
+ >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, pos_dim=192, num_heads=8, query_head_dim=32, pos_head_dim=4, value_head_dim=12, feedforward_dim=2048)
+ >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6, pos_dim=192, dropout=0.1, warmup_begin=1000.0, warmup_end=2000.0)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = zipformer_encoder(src)
+ """
+
+ def __init__(
+ self,
+ encoder_layer: nn.Module,
+ num_layers: int,
+ pos_dim: int,
+ dropout: float,
+ warmup_begin: float,
+ warmup_end: float,
+ initial_layerdrop_rate: float = 0.5,
+ final_layerdrop_rate: float = 0.05,
+ ) -> None:
+ super().__init__()
+ self.encoder_pos = CompactRelPositionalEncoding(
+ pos_dim, dropout_rate=0.15, length_factor=1.0
+ )
+
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ assert 0 <= warmup_begin <= warmup_end
+
+ delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
+ cur_begin = warmup_begin # interpreted as a training batch index
+ for i in range(num_layers):
+ cur_end = cur_begin + delta
+ self.layers[i].bypass.skip_rate = ScheduledFloat(
+ (cur_begin, initial_layerdrop_rate),
+ (cur_end, final_layerdrop_rate),
+ default=0.0,
+ )
+ cur_begin = cur_end
+
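+ # Worked example of the layerdrop schedule (illustrative numbers): with num_layers=3,
+ # warmup_begin=100 and warmup_end=400, delta=100, so layer 0's skip_rate ramps from
+ # initial_layerdrop_rate (0.5) at batch 100 down to final_layerdrop_rate by batch 200,
+ # layer 1 over batches 200-300 and layer 2 over batches 300-400; later layers are
+ # therefore dropped at a high rate for longer.
+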
+ def forward(
+ self,
+ src: Tensor,
+ chunk_size: int = -1,
+ feature_mask: Union[Tensor, float] = 1.0,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ chunk_size: the number of frames per chunk (must be > 0 if chunking is used); -1 means no chunking.
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
+ by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ pos_emb = self.encoder_pos(src)
+ output = src
+
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ output = output * feature_mask
+
+ for i, mod in enumerate(self.layers):
+ output = mod(
+ output,
+ pos_emb,
+ chunk_size=chunk_size,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ output = output * feature_mask
+
+ return output
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ states: List[Tensor],
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, List[Tensor]]:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
+ (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ left_context_len: Number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape
+ (batch_size, left_context_len + seq_len); True means masked position.
+ May be None.
+
+ Returns:
+ - output, a Tensor with the same shape as src.
+ - updated states
+ """
+ pos_emb = self.encoder_pos(src, left_context_len)
+ output = src
+
+ new_states = []
+ for i, mod in enumerate(self.layers):
+ (
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ) = states[i * 6 : (i + 1) * 6]
+ (
+ output,
+ new_cached_key,
+ new_cached_nonlin_attn,
+ new_cached_val1,
+ new_cached_val2,
+ new_cached_conv1,
+ new_cached_conv2,
+ ) = mod.streaming_forward(
+ output,
+ pos_emb,
+ cached_key=cached_key,
+ cached_nonlin_attn=cached_nonlin_attn,
+ cached_val1=cached_val1,
+ cached_val2=cached_val2,
+ cached_conv1=cached_conv1,
+ cached_conv2=cached_conv2,
+ left_context_len=left_context_len,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ new_states += [
+ new_cached_key,
+ new_cached_nonlin_attn,
+ new_cached_val1,
+ new_cached_val2,
+ new_cached_conv1,
+ new_cached_conv2,
+ ]
+
+ return output, new_states
+
+
+class BypassModule(nn.Module):
+ """
+ An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
+ layer-skipping. The bypass is limited during early stages of training to be close to
+ "straight-through", i.e. to not do the bypass operation much initially, in order to
+ force all the modules to learn something.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ skip_rate: FloatLike = 0.0,
+ straight_through_rate: FloatLike = 0.0,
+ scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
+ scale_max: FloatLike = 1.0,
+ ):
+ super().__init__()
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+ self.skip_rate = copy.deepcopy(skip_rate)
+ self.straight_through_rate = copy.deepcopy(straight_through_rate)
+ self.scale_min = copy.deepcopy(scale_min)
+ self.scale_max = copy.deepcopy(scale_max)
+
+ def _get_bypass_scale(self, batch_size: int):
+ # returns bypass-scale of shape (num_channels,),
+ # or (batch_size, num_channels,). This is actually the
+ # scale on the non-residual term, so 0 corresponds to bypassing
+ # this module.
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+ return self.bypass_scale
+ else:
+ ans = limit_param_value(
+ self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
+ )
+ skip_rate = float(self.skip_rate)
+ if skip_rate != 0.0:
+ mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
+ ans = ans * mask
+ # now ans is of shape (batch_size, num_channels), and is zero for sequences
+ # on which we have randomly chosen to do layer-skipping.
+ straight_through_rate = float(self.straight_through_rate)
+ if straight_through_rate != 0.0:
+ mask = (
+ torch.rand((batch_size, 1), device=ans.device)
+ < straight_through_rate
+ )
+ ans = torch.maximum(ans, mask.to(ans.dtype))
+ return ans
+
+ def forward(self, src_orig: Tensor, src: Tensor):
+ """
+ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
+ Returns: something with the same shape as src and src_orig
+ """
+ bypass_scale = self._get_bypass_scale(src.shape[1])
+ return src_orig + (src - src_orig) * bypass_scale
+
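+ # In other words, the module returns src_orig + (src - src_orig) * bypass_scale,
+ # a per-channel (and, during training, per-sequence) interpolation: a scale of 0
+ # returns src_orig unchanged (full bypass) and a scale of 1 returns src (no bypass);
+ # e.g. with bypass_scale = 0.25 the output is 0.75 * src_orig + 0.25 * src.
+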
+
+class DownsampledZipformer2Encoder(nn.Module):
+ r"""
+ DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate:
+ the input is downsampled, passed through the wrapped encoder, upsampled again at the output,
+ and combined with the original input, so that the output has the same shape as the input.
+ """
+
+ def __init__(
+ self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
+ ):
+ super(DownsampledZipformer2Encoder, self).__init__()
+ self.downsample_factor = downsample
+ self.downsample = SimpleDownsample(dim, downsample, dropout)
+ self.num_layers = encoder.num_layers
+ self.encoder = encoder
+ self.upsample = SimpleUpsample(dim, downsample)
+ self.out_combiner = BypassModule(dim, straight_through_rate=0)
+
+ def forward(
+ self,
+ src: Tensor,
+ chunk_size: int = -1,
+ feature_mask: Union[Tensor, float] = 1.0,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Downsample, go through encoder, upsample.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
+ by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ src_orig = src
+ src = self.downsample(src)
+ ds = self.downsample_factor
+ if attn_mask is not None:
+ attn_mask = attn_mask[::ds, ::ds]
+
+ src = self.encoder(
+ src,
+ chunk_size=chunk_size // ds,
+ feature_mask=feature_mask,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ src = self.upsample(src)
+ # remove any extra frames that are not a multiple of downsample_factor
+ src = src[: src_orig.shape[0]]
+
+ return self.out_combiner(src_orig, src)
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ states: List[Tensor],
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, List[Tensor]]:
+ r"""Downsample, go through encoder, upsample, in streaming forward mode.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
+ (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ left_context_len: Number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len);
+ True means masked position. May be None.
+
+ Returns:
+ - output, a Tensor with the same shape as src.
+ - updated states
+ """
+ src_orig = src
+ src = self.downsample(src)
+
+ src, new_states = self.encoder.streaming_forward(
+ src,
+ states=states,
+ left_context_len=left_context_len,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ src = self.upsample(src)
+ # remove any extra frames that are not a multiple of downsample_factor
+ src = src[: src_orig.shape[0]]
+
+ return self.out_combiner(src_orig, src), new_states
+
+
+class SimpleDownsample(torch.nn.Module):
+ """
+ Does downsampling by a learned, softmax-normalized weighted sum over each group of `downsample` consecutive frames.
+ """
+
+ def __init__(self, channels: int, downsample: int, dropout: FloatLike):
+ super(SimpleDownsample, self).__init__()
+
+ self.bias = nn.Parameter(torch.zeros(downsample))
+
+ self.name = None # will be set from training code
+ self.dropout = copy.deepcopy(dropout)
+
+ self.downsample = downsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ src: (seq_len, batch_size, in_channels)
+ Returns a tensor of shape
+ ((seq_len+downsample-1)//downsample, batch_size, in_channels)
+ """
+ (seq_len, batch_size, in_channels) = src.shape
+ ds = self.downsample
+ d_seq_len = (seq_len + ds - 1) // ds
+
+ # Pad to an exact multiple of self.downsample
+ # right-pad src, repeating the last element.
+ pad = d_seq_len * ds - seq_len
+ src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+ src = torch.cat((src, src_extra), dim=0)
+ assert src.shape[0] == d_seq_len * ds
+
+ src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+
+ weights = self.bias.softmax(dim=0)
+ # weights: (downsample, 1, 1)
+ weights = weights.unsqueeze(-1).unsqueeze(-1)
+
+ # ans: (d_seq_len, batch_size, in_channels)
+ ans = (src * weights).sum(dim=1)
+
+ return ans
+
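+ # Worked example (illustrative): with seq_len=5 and downsample=2, one copy of the last
+ # frame is appended to reach length 6, the sequence is reshaped to (3, 2, batch, channels),
+ # and each output frame is a softmax(self.bias)-weighted average of its pair of input
+ # frames, with the two weights shared across time and learned.
+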
+
+class SimpleUpsample(torch.nn.Module):
+ """
+ A very simple form of upsampling that just repeats each input frame `upsample` times.
+ """
+
+ def __init__(self, num_channels: int, upsample: int):
+ super(SimpleUpsample, self).__init__()
+ self.upsample = upsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ src: (seq_len, batch_size, num_channels)
+ Returns a tensor of shape
+ ( (seq_len*upsample), batch_size, num_channels)
+ """
+ upsample = self.upsample
+ (seq_len, batch_size, num_channels) = src.shape
+ src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+ src = src.reshape(seq_len * upsample, batch_size, num_channels)
+ return src
+
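+ # e.g. with upsample=2, an input [f0, f1, f2] along the time axis becomes
+ # [f0, f0, f1, f1, f2, f2]; no interpolation is applied.
+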
+
+class CompactRelPositionalEncoding(torch.nn.Module):
+ """
+ Relative positional encoding module. This version is "compact" meaning it is able to encode
+ the important information about the relative position in a relatively small number of dimensions.
+ The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
+ make very little difference to the embedding. Such differences were potentially important
+ when encoding absolute position, but not important when encoding relative position because there
+ is now no need to compare two large offsets with each other.
+
+ Our embedding works by projecting the interval [-infinity, infinity] to a finite interval
+ using the atan() function, and then taking the Fourier transform of that fixed interval. On its
+ own, the atan() function would compress the "long tails" too much,
+ making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
+ function to compress large offsets to a smaller range before applying atan().
+ Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
+ as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim).
+
+
+ Args:
+ embed_dim: Embedding dimension.
+ dropout_rate: Dropout rate.
+ max_len: Maximum input length: just a heuristic for initialization.
+ length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+ less weight to small differences of offset near the origin.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ dropout_rate: FloatLike,
+ max_len: int = 1000,
+ length_factor: float = 1.0,
+ ) -> None:
+ """Construct a CompactRelPositionalEncoding object."""
+ super(CompactRelPositionalEncoding, self).__init__()
+ self.embed_dim = embed_dim
+ assert embed_dim % 2 == 0
+ self.dropout = Dropout2(dropout_rate)
+ self.pe = None
+ assert length_factor >= 1.0
+ self.length_factor = length_factor
+ self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+ """Reset the positional encodings."""
+ T = x.size(0) + left_context_len
+
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(0) >= T * 2 - 1:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+
+ # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
+ x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
+
+ freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+
+ # `compression_length` is arbitrary/heuristic; if it is larger we have more resolution
+ # for small time offsets but less resolution for large time offsets.
+ compression_length = self.embed_dim**0.5
+ # x_compressed, like x, goes from -infinity to infinity as x goes from -infinity to infinity;
+ # but it does so more slowly than x for large absolute values of x.
+ # The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which
+ # is important.
+ x_compressed = (
+ compression_length
+ * x.sign()
+ * ((x.abs() + compression_length).log() - math.log(compression_length))
+ )
+
+ # if self.length_factor == 1.0, then length_scale is chosen so that the
+ # FFT can exactly separate points close to the origin (T == 0). So this
+ # part of the formulation is not really heuristic.
+ # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+ length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
+
+ # note for machine implementations: if atan is not available, we can use:
+ # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
+ # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
+ x_atan = (x_compressed / length_scale).atan() # results between -pi and pi
+
+ cosines = (x_atan * freqs).cos()
+ sines = (x_atan * freqs).sin()
+
+ pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+ pe[:, 0::2] = cosines
+ pe[:, 1::2] = sines
+ pe[:, -1] = 1.0 # for bias.
+
+ self.pe = pe.to(dtype=x.dtype)
+
+ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+ """Create positional encoding.
+
+ Args:
+ x (Tensor): Input tensor (time, batch, `*`).
+ left_context_len: (int): Length of cached left context.
+
+ Returns:
+ positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
+ """
+ self.extend_pe(x, left_context_len)
+ x_size_left = x.size(0) + left_context_len
+ # length of positive side: x.size(0) + left_context_len
+ # length of negative side: x.size(0)
+ pos_emb = self.pe[
+ self.pe.size(0) // 2
+ - x_size_left
+ + 1 : self.pe.size(0) // 2 # noqa E203
+ + x.size(0),
+ :,
+ ]
+ pos_emb = pos_emb.unsqueeze(0)
+ return self.dropout(pos_emb)
+
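+ # Shape sketch (illustrative values): for x of shape (100, 8, 384) with embed_dim=192
+ # and left_context_len=0, the returned positional embedding has shape (1, 2*100 - 1, 192),
+ # covering relative offsets from -(100 - 1) to +(100 - 1).
+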
+
+class RelPositionMultiheadAttentionWeights(nn.Module):
+ r"""Module that computes multi-head attention weights with relative position encoding.
+ Various other modules consume the resulting attention weights: see, for example, the
+ SimpleAttention module which allows you to compute conventional attention.
+
+ This is quite heavily modified from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context";
+ we have to write up the differences.
+
+
+ Args:
+ embed_dim: number of channels at the input to this module, e.g. 256
+ pos_dim: dimension of the positional encoding vectors, e.g. 128.
+ num_heads: number of heads to compute weights for, e.g. 8
+ query_head_dim: dimension of the query (and key), per head. e.g. 24.
+ pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
+ dropout: dropout probability for attn_output_weights. Default: 0.0.
+ pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
+ any given call to forward(), in training time.
+ lora_r: the bottleneck dimension of LoRA
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ dropout: float = 0.0,
+ pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
+ lora_r: int = 0,
+ lora_alpha: int = 4,
+ lora_dropout: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.query_head_dim = query_head_dim
+ self.pos_head_dim = pos_head_dim
+ self.dropout = dropout
+ self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
+ self.name = None # will be overwritten in training code; for diagnostics.
+
+ key_head_dim = query_head_dim
+ in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
+
+ # the initial_scale is supposed to take over the "scaling" factor of
+ # head_dim ** -0.5 that has been used in previous forms of attention,
+ # dividing it between the query and key. Note: this module is intended
+ # to be used with the ScaledAdam optimizer; with most other optimizers,
+ # it would be necessary to apply the scaling factor in the forward function.
+ # self.in_proj = ScaledLinear(
+ # embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
+ # )
+ self.in_proj = ScaledLinear_lora(
+ in_features=embed_dim,
+ out_features=in_proj_dim,
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ initial_scale=query_head_dim**-0.25,
+ bias=True,
+ )
+
+ self.whiten_keys = Whiten(
+ num_groups=num_heads,
+ whitening_limit=_whitening_schedule(3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.025,
+ )
+
+ # add a balancer for the keys that runs with very small probability, and
+ # tries to enforce that all dimensions have mean around zero. The
+ # weights produced by this module are invariant to adding a constant to
+ # the keys, so the derivative of the bias is mathematically zero; but
+ # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
+ # bias because the small numerical roundoff tends to have a non-random
+ # sign. This module is intended to prevent that. Use a very small
+ # probability; that should be sufficient to fix the problem.
+ self.balance_keys = Balancer(
+ key_head_dim * num_heads,
+ channel_dim=-1,
+ min_positive=0.4,
+ max_positive=0.6,
+ min_abs=0.0,
+ max_abs=100.0,
+ prob=0.025,
+ )
+
+ # linear transformation for positional encoding.
+ self.linear_pos = ScaledLinear(
+ pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
+ )
+
+ # the following are for diagnostics only, see --print-diagnostics option
+ self.copy_pos_query = Identity()
+ self.copy_query = Identity()
+
+ def forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""
+ Args:
+ x: input of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
+ are True in this mask will be ignored as sources in the attention weighting.
+ attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
+ interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
+ saying which positions are allowed to attend to which other positions.
+ Returns:
+ a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
+ interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+ """
+ x = self.in_proj(x)
+ query_head_dim = self.query_head_dim
+ pos_head_dim = self.pos_head_dim
+ num_heads = self.num_heads
+
+ seq_len, batch_size, _ = x.shape
+
+ query_dim = query_head_dim * num_heads
+
+ # self-attention
+ q = x[..., 0:query_dim]
+ k = x[..., query_dim : 2 * query_dim]
+ # p is the position-encoding query
+ p = x[..., 2 * query_dim :]
+ assert p.shape[-1] == num_heads * pos_head_dim
+
+ q = self.copy_query(q) # for diagnostics only, does nothing.
+ k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
+ p = self.copy_pos_query(p) # for diagnostics only, does nothing.
+
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+ k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
+
+ # time1 refers to target, time2 refers to source.
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+
+ attn_scores = torch.matmul(q, k)
+
+ use_pos_scores = False
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ # random.random() cannot be evaluated when scripting/tracing, so always use pos scores
+ use_pos_scores = True
+ elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+ use_pos_scores = True
+
+ if use_pos_scores:
+ pos_emb = self.linear_pos(pos_emb)
+ seq_len2 = 2 * seq_len - 1
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+ 2, 0, 3, 1
+ )
+ # pos_emb shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2)
+
+ # (head, batch, time1, pos_head_dim) x (head, 1, pos_head_dim, seq_len2) -> (head, batch, time1, seq_len2)
+ # [where seq_len2 represents relative position.]
+ pos_scores = torch.matmul(p, pos_emb)
+ # the following .as_strided() expression converts the last axis of pos_scores from relative
+ # to absolute position. I don't know whether I might have got the time-offsets backwards or
+ # not, but let this code define which way round it is supposed to be.
+ if torch.jit.is_tracing():
+ (num_heads, batch_size, time1, n) = pos_scores.shape
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+ cols = torch.arange(seq_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+ pos_scores = pos_scores.reshape(-1, n)
+ pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+ pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
+ else:
+ pos_scores = pos_scores.as_strided(
+ (num_heads, batch_size, seq_len, seq_len),
+ (
+ pos_scores.stride(0),
+ pos_scores.stride(1),
+ pos_scores.stride(2) - pos_scores.stride(3),
+ pos_scores.stride(3),
+ ),
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
+ )
+
+ attn_scores = attn_scores + pos_scores
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif self.training and random.random() < 0.1:
+ # This is a harder way of limiting the attention scores to not be
+ # too large. It incurs a penalty if any of them has an absolute
+ # value greater than 25.0. This should be outside the normal range
+ # of the attention scores. We use this mechanism instead of, say,
+ # something added to the loss function involving the entropy,
+ # because once the entropy gets very small, gradients through the
+ # softmax can become very small and we'd get zero derivatives. The
+ # choice of 1.0e-04 as the scale on the penalty makes this
+ # mechanism vulnerable to the absolute scale of the loss function,
+ # but we view this as a failsafe to avoid "implausible" parameter
+ # values rather than a regularization method that should be active
+ # under normal circumstances.
+ attn_scores = penalize_abs_values_gt(
+ attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
+ )
+
+ assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ if attn_mask is not None:
+ assert attn_mask.dtype == torch.bool
+ # use -1000 to avoid nan's where attn_mask and key_padding_mask make
+ # all scores zero. It's important that this be large enough that exp(-1000)
+ # is exactly zero, for reasons related to const_attention_rate, it
+ # compares the final weights with zero.
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (
+ batch_size,
+ seq_len,
+ ), key_padding_mask.shape
+ attn_scores = attn_scores.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ -1000,
+ )
+
+ # We use our own version of softmax, defined in scaling.py, which should
+ # save a little of the memory used in backprop: if we are in
+ # automatic mixed precision mode (amp / autocast), it only stores the
+ # half-precision output for backprop purposes.
+ attn_weights = softmax(attn_scores, dim=-1)
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif random.random() < 0.001 and not self.training:
+ self._print_attn_entropy(attn_weights)
+
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+
+ return attn_weights
+
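+ # Shape sketch (illustrative values): with embed_dim=384, num_heads=8, query_head_dim=32
+ # and pos_head_dim=4, in_proj produces (32 + 32 + 4) * 8 = 544 channels per frame, split
+ # into q, k and the positional query p; the returned attention weights have shape
+ # (8, batch_size, seq_len, seq_len) and, up to dropout, sum to 1 over the last axis.
+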
+ def streaming_forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ cached_key: Tensor,
+ left_context_len: int,
+ key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ r"""
+ Args:
+ x: input of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
+ cached_key: cached attention key tensor of left context,
+ of shape (left_context_len, batch_size, key_dim)
+ left_context_len: number of left context frames.
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
+ are True in this mask will be ignored as sources in the attention weighting.
+
+ Returns:
+ - attention weights, of shape (num_heads, batch_size, seq_len, seq_len2),
+ interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+ - updated cached attention key tensor of left context.
+ """
+ x = self.in_proj(x)
+ query_head_dim = self.query_head_dim
+ pos_head_dim = self.pos_head_dim
+ num_heads = self.num_heads
+
+ seq_len, batch_size, _ = x.shape
+
+ query_dim = query_head_dim * num_heads
+
+ # self-attention
+ q = x[..., 0:query_dim]
+ k = x[..., query_dim : 2 * query_dim]
+ # p is the position-encoding query
+ p = x[..., 2 * query_dim :]
+ assert p.shape[-1] == num_heads * pos_head_dim
+
+ # Pad cached left contexts
+ assert cached_key.shape[0] == left_context_len, (
+ cached_key.shape[0],
+ left_context_len,
+ )
+ k = torch.cat([cached_key, k], dim=0)
+ # Update cached left contexts
+ cached_key = k[-left_context_len:, ...]
+
+ # The length of key
+ k_len = k.shape[0]
+
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+ k = k.reshape(k_len, batch_size, num_heads, query_head_dim)
+
+ # time1 refers to target, time2 refers to source.
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+
+ attn_scores = torch.matmul(q, k)
+
+ pos_emb = self.linear_pos(pos_emb)
+ seq_len2 = 2 * seq_len - 1 + left_context_len
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+ 2, 0, 3, 1
+ )
+ # pos_emb shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2)
+
+ # (head, batch, time1, pos_head_dim) x (head, 1, pos_head_dim, seq_len2) -> (head, batch, time1, seq_len2)
+ # [where seq_len2 represents relative position.]
+ pos_scores = torch.matmul(p, pos_emb)
+
+ if torch.jit.is_tracing():
+ (num_heads, batch_size, time1, n) = pos_scores.shape
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+ cols = torch.arange(k_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+ pos_scores = pos_scores.reshape(-1, n)
+ pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+ pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
+ # the following .as_strided() expression converts the last axis of pos_scores from relative
+ # to absolute position. I don't know whether I might have got the time-offsets backwards or
+ # not, but let this code define which way round it is supposed to be.
+ else:
+ pos_scores = pos_scores.as_strided(
+ (num_heads, batch_size, seq_len, k_len),
+ (
+ pos_scores.stride(0),
+ pos_scores.stride(1),
+ pos_scores.stride(2) - pos_scores.stride(3),
+ pos_scores.stride(3),
+ ),
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
+ )
+
+ attn_scores = attn_scores + pos_scores
+
+ assert attn_scores.shape == (
+ num_heads,
+ batch_size,
+ seq_len,
+ k_len,
+ ), attn_scores.shape
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
+ attn_scores = attn_scores.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ -1000,
+ )
+
+ attn_weights = attn_scores.softmax(dim=-1)
+
+ return attn_weights, cached_key
+
+ def _print_attn_entropy(self, attn_weights: Tensor):
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ attn_weights = attn_weights.to(torch.float32)
+ attn_weights_entropy = (
+ -((attn_weights + 1.0e-20).log() * attn_weights)
+ .sum(dim=-1)
+ .mean(dim=(1, 2))
+ )
+ logging.info(
+ f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
+ )
+
+
+class SelfAttention(nn.Module):
+ """
+ The simplest possible attention module. This one works with already-computed attention
+ weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
+
+ Args:
+ embed_dim: the input and output embedding dimension
+ num_heads: the number of attention heads
+ value_head_dim: the value dimension per head
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ value_head_dim: int,
+ lora_r: int = 0,
+ lora_alpha: int = 4,
+ lora_dropout: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.in_proj = ScaledLinear_lora(
+ in_features=embed_dim,
+ out_features=num_heads * value_head_dim,
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ bias=True,
+ )
+
+ self.out_proj = ScaledLinear(
+ num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """
+ Args:
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+ with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
+ attn_weights.sum(dim=-1) == 1.
+ Returns:
+ a tensor with the same shape as x.
+ """
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
+ value_head_dim = x.shape[-1]
+
+ # todo: see whether there is benefit in overriding matmul
+ x = torch.matmul(attn_weights, x)
+ # v: (num_heads, batch_size, seq_len, value_head_dim)
+
+ x = (
+ x.permute(2, 1, 0, 3)
+ .contiguous()
+ .view(seq_len, batch_size, num_heads * value_head_dim)
+ )
+
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+ x = self.out_proj(x)
+ x = self.whiten(x)
+
+ return x
+
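+ # Conceptually, for each head h and target frame t this computes
+ #   out[t] = sum_s attn_weights[h, :, t, s] * value[h, :, s, :]
+ # i.e. a weighted average of the projected values, followed by out_proj back to embed_dim.
+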
+ def streaming_forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ cached_val: Tensor,
+ left_context_len: int,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+ with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
+ attn_weights.sum(dim=-1) == 1.
+ cached_val: cached attention value tensor of left context,
+ of shape (left_context_len, batch_size, value_dim)
+ left_context_len: number of left context frames.
+
+ Returns:
+ - attention weighted output, a tensor with the same shape as x.
+ - updated cached attention value tensor of left context.
+ """
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ seq_len2 = seq_len + left_context_len
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)
+
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
+
+ # Pad cached left contexts
+ assert cached_val.shape[0] == left_context_len, (
+ cached_val.shape[0],
+ left_context_len,
+ )
+ x = torch.cat([cached_val, x], dim=0)
+ # Update cached left contexts
+ cached_val = x[-left_context_len:, ...]
+
+ x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
+ value_head_dim = x.shape[-1]
+
+ # todo: see whether there is benefit in overriding matmul
+ x = torch.matmul(attn_weights, x)
+        # now x: (num_heads, batch_size, seq_len, value_head_dim)
+
+ x = (
+ x.permute(2, 1, 0, 3)
+ .contiguous()
+ .view(seq_len, batch_size, num_heads * value_head_dim)
+ )
+
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+ x = self.out_proj(x)
+
+ return x, cached_val
+
+
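+# Hedged usage sketch (illustrative only): SelfAttention just projects values
+# and mixes them with attention weights computed elsewhere, so a uniform
+# weight tensor whose rows sum to 1 is enough to exercise it.  The toy
+# dimensions below are arbitrary placeholders.
+def _demo_self_attention() -> None:
+    seq_len, batch_size, embed_dim, num_heads, value_head_dim = 10, 2, 64, 4, 12
+    attn = SelfAttention(embed_dim, num_heads, value_head_dim)
+    x = torch.randn(seq_len, batch_size, embed_dim)
+    # Uniform attention weights of shape (num_heads, batch_size, tgt, src);
+    # each row sums to 1, as the forward() docstring expects.
+    attn_weights = torch.full(
+        (num_heads, batch_size, seq_len, seq_len), 1.0 / seq_len
+    )
+    y = attn(x, attn_weights)
+    assert y.shape == x.shape
+
+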
+class FeedforwardModule(nn.Module):
+ """Feedforward module in Zipformer2 model."""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ feedforward_dim: int,
+ dropout: FloatLike,
+ lora_r: int = 0,
+ lora_alpha: int = 4,
+ lora_dropout: float = 0.0,
+ ):
+ super(FeedforwardModule, self).__init__()
+ self.in_proj = ScaledLinear_lora(
+ in_features=embed_dim,
+ out_features=feedforward_dim,
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ bias=True,
+ )
+
+ self.hidden_balancer = Balancer(
+ feedforward_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=1.0,
+ min_abs=0.75,
+ max_abs=5.0,
+ )
+
+ # shared_dim=0 means we share the dropout mask along the time axis
+ self.out_proj = ActivationDropoutAndLinear_lora(
+ feedforward_dim,
+ embed_dim,
+ activation="SwooshL",
+ dropout_p=dropout,
+ dropout_shared_dim=0,
+ bias=True,
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ initial_scale=0.1,
+ )
+
+ self.out_whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(self, x: Tensor):
+ x = self.in_proj(x)
+ x = self.hidden_balancer(x)
+ # out_proj contains SwooshL activation, then dropout, then linear.
+ x = self.out_proj(x)
+ x = self.out_whiten(x)
+ return x
+
+
+class NonlinAttention(nn.Module):
+ """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
+ from the attention module) in place of actual convolution. We also took out the second nonlinearity, the
+ one after the attention mechanism.
+
+ Args:
+ channels (int): The number of channels of conv layers.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ hidden_channels: int,
+ ) -> None:
+ super().__init__()
+
+ self.hidden_channels = hidden_channels
+
+ self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
+
+ # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0,
+ # because we noticed that well-trained instances of this module have abs-value before the sigmoid
+ # starting from about 3, and poorly-trained instances of the module have smaller abs values
+ # before the sigmoid.
+ self.balancer = Balancer(
+ hidden_channels,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
+ max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
+ min_abs=0.5,
+ max_abs=5.0,
+ )
+ self.tanh = nn.Tanh()
+
+ self.identity1 = Identity() # for diagnostics.
+ self.identity2 = Identity() # for diagnostics.
+ self.identity3 = Identity() # for diagnostics.
+
+ self.out_proj = ScaledLinear(
+ hidden_channels, channels, bias=True, initial_scale=0.05
+ )
+
+ self.whiten1 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.whiten2 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """.
+ Args:
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+ Returns:
+ a Tensor with the same shape as x
+ """
+ x = self.in_proj(x)
+
+ (seq_len, batch_size, _) = x.shape
+ hidden_channels = self.hidden_channels
+
+ s, x, y = x.chunk(3, dim=2)
+
+ # s will go through tanh.
+
+ s = self.balancer(s)
+ s = self.tanh(s)
+
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+ x = self.whiten1(x)
+ x = x * s
+ x = self.identity1(x) # diagnostics only, it's the identity.
+
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = torch.matmul(attn_weights, x)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+ y = self.identity2(y)
+ x = x * y
+ x = self.identity3(x)
+
+ x = self.out_proj(x)
+ x = self.whiten2(x)
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ cached_x: Tensor,
+ left_context_len: int,
+ ) -> Tuple[Tensor, Tensor]:
+ """.
+ Args:
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
+            attn_weights: a Tensor of shape
+                (num_heads, batch_size, seq_len, left_context_len + seq_len)
+ cached_x: left context, a Tensor of shape
+ (num_heads, batch_size, left_context_len, head_dim)
+ left_context_len: number of left context frames.
+ Returns:
+ - a Tensor with the same shape as x
+ - updated left context with same shape as cached_x
+ """
+ x = self.in_proj(x)
+
+ (seq_len, batch_size, _) = x.shape
+ hidden_channels = self.hidden_channels
+
+ s, x, y = x.chunk(3, dim=2)
+
+ # s will go through tanh.
+ s = self.tanh(s)
+
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+ x = x * s
+
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (
+ num_heads,
+ batch_size,
+ seq_len,
+ left_context_len + seq_len,
+ )
+
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+
+ # Pad cached tensor
+ assert cached_x.shape[2] == left_context_len, (
+ cached_x.shape[2],
+ left_context_len,
+ )
+ x_pad = torch.cat([cached_x, x], dim=2)
+ # Update cached tensor
+ cached_x = x_pad[:, :, -left_context_len:, :]
+
+ x = torch.matmul(attn_weights, x_pad)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+ x = x * y
+
+ x = self.out_proj(x)
+ return x, cached_x
+
+
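+# Hedged usage sketch (illustrative only): NonlinAttention splits its input
+# projection into (s, x, y) chunks, gates x with tanh(s), mixes it per head
+# with the shared attention weights, then gates again with y, so
+# hidden_channels must be divisible by the number of heads.  The sizes below
+# are arbitrary placeholders.
+def _demo_nonlin_attention() -> None:
+    seq_len, batch_size, channels, hidden_channels, num_heads = 10, 2, 64, 96, 4
+    m = NonlinAttention(channels, hidden_channels)
+    x = torch.randn(seq_len, batch_size, channels)
+    attn_weights = torch.randn(
+        num_heads, batch_size, seq_len, seq_len
+    ).softmax(dim=-1)
+    y = m(x, attn_weights)
+    assert y.shape == x.shape
+
+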
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Zipformer2 model.
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
+
+ Args:
+ channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+        causal (bool): Whether to use the causal (chunk-wise) variant of the
+            depthwise convolution.
+
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ causal: bool,
+ ) -> None:
+ """Construct a ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ bottleneck_dim = channels
+ self.causal = causal
+
+ self.in_proj = nn.Linear(
+ channels,
+ 2 * bottleneck_dim,
+ )
+ # the gradients on in_proj are a little noisy, likely to do with the
+ # sigmoid in glu.
+
+ # after in_proj we put x through a gated linear unit (nn.functional.glu).
+ # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+ # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+ # between 50 and 100 for different channels. This will cause very peaky and
+ # sparse derivatives for the sigmoid gating function, which will tend to make
+ # the loss function not learn effectively. (for most layers the average absolute values
+ # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+ # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
+ # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid.) The idea is that if we
+ # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+ # it will be in a better position to start learning something, i.e. to latch onto
+ # the correct range.
+ self.balancer1 = Balancer(
+ bottleneck_dim,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
+ max_positive=1.0,
+ min_abs=1.5,
+ max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
+ )
+
+ self.activation1 = Identity() # for diagnostics
+
+ self.sigmoid = nn.Sigmoid()
+
+ self.activation2 = Identity() # for diagnostics
+
+ assert kernel_size % 2 == 1
+
+ self.depthwise_conv = (
+ ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
+ if causal
+ else nn.Conv1d(
+ in_channels=bottleneck_dim,
+ out_channels=bottleneck_dim,
+ groups=bottleneck_dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+
+ self.balancer2 = Balancer(
+ bottleneck_dim,
+ channel_dim=1,
+ min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
+ max_positive=1.0,
+ min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
+ max_abs=10.0,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.out_proj = ActivationDropoutAndLinear(
+ bottleneck_dim,
+ channels,
+ activation="SwooshR",
+ dropout_p=0.0,
+ initial_scale=0.05,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ chunk_size: int = -1,
+ ) -> Tensor:
+ """Compute convolution module.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ src_key_padding_mask: the mask for the src keys per batch (optional):
+                (batch, #time), contains True in masked positions.
+            chunk_size: the chunk size (in frames) used for simulated streaming
+                decoding with the causal depthwise conv; -1 means no chunking.
+
+ Returns:
+ Tensor: Output tensor (#time, batch, channels).
+
+ """
+
+ x = self.in_proj(x) # (time, batch, 2*channels)
+
+ x, s = x.chunk(2, dim=2)
+ s = self.balancer1(s)
+ s = self.sigmoid(s)
+ x = self.activation1(x) # identity.
+ x = x * s
+ x = self.activation2(x) # identity
+
+ # (time, batch, channels)
+
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ if src_key_padding_mask is not None:
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+ if (
+ not torch.jit.is_scripting()
+ and not torch.jit.is_tracing()
+ and chunk_size >= 0
+ ):
+            # Exporting a model for simulated streaming decoding is not supported
+ assert (
+ self.causal
+ ), "Must initialize model with causal=True if you use chunk_size"
+ x = self.depthwise_conv(x, chunk_size=chunk_size)
+ else:
+ x = self.depthwise_conv(x)
+
+ x = self.balancer2(x)
+ x = x.permute(2, 0, 1) # (time, batch, channels)
+
+ x = self.whiten(x) # (time, batch, channels)
+ x = self.out_proj(x) # (time, batch, channels)
+
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ cache: Tensor,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """Compute convolution module in streaming forward mode.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ cache: cached left context for depthwise_conv of shape
+ (#batch, channels, left_pad)
+ src_key_padding_mask: the mask for the src keys per batch (optional):
+ (batch, #time), contains True in masked positions.
+
+ Returns:
+ - Output tensor (#time, batch, channels).
+ - Updated cache (#batch, channels, left_pad)
+ """
+
+ x = self.in_proj(x) # (time, batch, 2*channels)
+
+ x, s = x.chunk(2, dim=2)
+ s = self.sigmoid(s)
+ x = x * s
+ # (time, batch, channels)
+
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ if src_key_padding_mask is not None:
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+ x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)
+
+ x = x.permute(2, 0, 1) # (time, batch, channels)
+
+ x = self.out_proj(x) # (time, batch, channels)
+
+ return x, cache
+
+
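+# Hedged usage sketch (illustrative only): with causal=False the module reduces
+# to a depthwise convolution with 'SAME' padding, so the output keeps the
+# (time, batch, channels) shape of the input.  Kernel size and toy dimensions
+# are placeholders.
+def _demo_convolution_module() -> None:
+    seq_len, batch_size, channels = 10, 2, 64
+    m = ConvolutionModule(channels=channels, kernel_size=31, causal=False)
+    x = torch.randn(seq_len, batch_size, channels)
+    y = m(x)
+    assert y.shape == x.shape
+
+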
+class ScalarMultiply(nn.Module):
+ def __init__(self, scale: float):
+ super().__init__()
+ self.scale = scale
+
+ def forward(self, x):
+ return x * self.scale
+
+
+def _test_zipformer_main(causal: bool = False):
+ c = Zipformer2(
+ encoder_dim=(64, 96),
+ encoder_unmasked_dim=(48, 64),
+ num_heads=(4, 4),
+ causal=causal,
+ chunk_size=(4,) if causal else (-1,),
+ left_context_frames=(64,),
+ )
+ batch_size = 5
+ seq_len = 20
+ # Just make sure the forward pass runs.
+ f = c(
+ torch.randn(seq_len, batch_size, 64),
+ torch.full((batch_size,), seq_len, dtype=torch.int64),
+ )
+ f[0].sum().backward()
+ c.eval()
+ f = c(
+ torch.randn(seq_len, batch_size, 64),
+ torch.full((batch_size,), seq_len, dtype=torch.int64),
+ )
+ f # to remove flake8 warnings
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.INFO)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ _test_zipformer_main(False)
+ _test_zipformer_main(True)
diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py
index dd8949523..c1785a328 100755
--- a/egs/librispeech/ASR/zipformer_mmi/train.py
+++ b/egs/librispeech/ASR/zipformer_mmi/train.py
@@ -79,6 +79,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon, UniqLexicon
from icefall.mmi import LFMMILoss
@@ -816,9 +817,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/SSL/hubert/asr_datamodule.py b/egs/librispeech/SSL/hubert/asr_datamodule.py
new file mode 100644
index 000000000..3746d8a3a
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/asr_datamodule.py
@@ -0,0 +1,287 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2024 Xiaomi Corporation (Author: Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from dataset import HubertAsrDataset
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class LibriSpeechAsrDataModule:
+ """
+ DataModule for ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies.",
+ )
+ group.add_argument(
+ "--full-libri",
+ type=str2bool,
+ default=True,
+ help="When enabled use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/wav"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=float,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+ group.add_argument(
+ "--do-normalize",
+ type=str2bool,
+ default=True,
+ help="whether to normalize the data",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ do_normalize: bool,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to create train dataset")
+ train = HubertAsrDataset(do_normalize=do_normalize)
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
+ logging.info("About to create dev dataset")
+ validate = HubertAsrDataset(do_normalize=do_normalize)
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet, do_normalize: bool) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = HubertAsrDataset(do_normalize=do_normalize)
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_clean_100_cuts(self) -> CutSet:
+ logging.info("About to get train-clean-100 cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_clean_360_cuts(self) -> CutSet:
+ logging.info("About to get train-clean-360 cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_other_500_cuts(self) -> CutSet:
+ logging.info("About to get train-other-500 cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_all_shuf_cuts(self) -> CutSet:
+ logging.info(
+ "About to get the shuffled train-clean-100, \
+ train-clean-360 and train-other-500 cuts"
+ )
+ train_clean_100_cuts = self.train_clean_100_cuts()
+ train_clean_360_cuts = self.train_clean_360_cuts()
+ train_other_500_cuts = self.train_other_500_cuts()
+ return CutSet.mux(
+ train_clean_100_cuts,
+ train_clean_360_cuts,
+ train_other_500_cuts,
+ weights=[
+ 28539, # len(train_clean_100_cuts)
+ 104014, # len(train_clean_360_cuts)
+ 148688, # len(train_other_500_cuts)
+ ],
+ )
+
+ @lru_cache()
+ def dev_clean_cuts(self) -> CutSet:
+ logging.info("About to get dev-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_other_cuts(self) -> CutSet:
+ logging.info("About to get dev-other cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_clean_cuts(self) -> CutSet:
+ logging.info("About to get test-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_other_cuts(self) -> CutSet:
+ logging.info("About to get test-other cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
+ )
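+
+
+if __name__ == "__main__":
+    # Hedged usage sketch: wire the datamodule options into argparse and build
+    # a training dataloader.  It assumes the cut manifests referenced by
+    # --manifest-dir have already been prepared; nothing here is required by
+    # the recipe itself.
+    parser = argparse.ArgumentParser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    librispeech = LibriSpeechAsrDataModule(args)
+    train_cuts = librispeech.train_clean_100_cuts()
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, do_normalize=args.do_normalize
+    )
+    for batch_idx, batch in enumerate(train_dl):
+        print(batch["audio"].shape, batch["padding_mask"].shape)
+        break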
diff --git a/egs/librispeech/SSL/hubert/attention_module.py b/egs/librispeech/SSL/hubert/attention_module.py
new file mode 100644
index 000000000..39ef8698e
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/attention_module.py
@@ -0,0 +1,840 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import utils
+from torch import Tensor, nn
+from torch.nn import Parameter
+from utils import FairseqDropout, quant_noise
+
+_xformers_available = False
+
+
+# TODO: move this into xformers?
+# TODO: uint8 input type should just output a bool
+def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None):
+ """
+    A call to the PyTorch multi-head attention accepts three mask types:
+    - ByteTensor where non-zero means to mask
+    - FloatTensor which is an additive mask
+    - BoolTensor where True means to mask
+    xFormers currently accepts boolean and additive masks. For boolean masks
+    the values have the opposite meaning: True means to keep the value.
+ """
+ float_types = [torch.float, torch.float16]
+ # If an input mask is a float it is an additive mask. Otherwise it is either uint8 or bool.
+ additive = mask.dtype in float_types
+    # If to_dtype is not specified, keep the same dtype as mask.
+ to_dtype = mask.dtype if to_dtype is None else to_dtype
+ to_additive = to_dtype in float_types
+
+ if additive:
+ if to_additive:
+ return mask.to(to_dtype)
+ mask = mask < 0
+
+ if to_additive:
+ # return additive mask
+ new_mask = torch.zeros_like(mask, dtype=to_dtype)
+ new_mask = new_mask.masked_fill_(mask, -float("inf"))
+ return new_mask
+
+ # In xFormers True is value to keep rather than value to mask
+ mask = ~mask.to(torch.bool)
+ mask = mask.to(to_dtype)
+ return mask
+
+
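+# Hedged example (illustrative only) of the conversions above: a boolean
+# "True means masked" tensor becomes an additive float mask (0 / -inf), while
+# converting to bool flips the convention to "True means keep", which is what
+# xFormers expects.
+def _demo_mask_for_xformers() -> None:
+    bool_mask = torch.tensor([[False, True], [False, False]])  # True = mask out
+    additive = _mask_for_xformers(bool_mask, to_dtype=torch.float)
+    assert additive[0, 0] == 0.0 and additive[0, 1] == float("-inf")
+    keep = _mask_for_xformers(bool_mask, to_dtype=torch.bool)
+    assert bool(keep[0, 0]) and not bool(keep[0, 1])
+
+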
+def init_bert_params(module):
+ """
+ Initialize the weights specific to the BERT Model.
+ This overrides the default initializations depending on the specified arguments.
+    1. If normal_init_linear_weights is set then weights of linear
+       layers will be initialized using the normal distribution and
+       biases will be set to the specified value.
+    2. If normal_init_embed_weights is set then weights of embedding
+       layers will be initialized using the normal distribution.
+    3. If normal_init_proj_weights is set then weights of
+       in_proj_weight for MultiheadAttention are initialized using
+       the normal distribution (to be validated).
+ """
+
+ def normal_(data):
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
+ # so that the RNG is consistent with and without FSDP
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
+
+ if isinstance(module, nn.Linear):
+ normal_(module.weight.data)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if isinstance(module, nn.Embedding):
+ normal_(module.weight.data)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ if isinstance(module, MultiheadAttention):
+ normal_(module.q_proj.weight.data)
+ normal_(module.k_proj.weight.data)
+ normal_(module.v_proj.weight.data)
+
+
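+# Hedged usage sketch: fairseq applies this initializer recursively via
+# nn.Module.apply, so Linear/Embedding weights become N(0, 0.02) and Linear
+# biases are zeroed.  The toy layer below is for illustration only.
+def _demo_init_bert_params() -> None:
+    layer = nn.Linear(8, 8)
+    layer.apply(init_bert_params)
+    assert float(layer.bias.abs().sum()) == 0.0
+
+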
+class MultiheadAttention(nn.Module):
+ """Multi-headed attention.
+
+ See "Attention Is All You Need" for more details.
+ """
+
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ kdim=None,
+ vdim=None,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ self_attention=False,
+ encoder_decoder_attention=False,
+ dictionary=None,
+ q_noise=0.0,
+ qn_block_size=8,
+ # TODO: pass in config rather than string.
+ # config defined in xformers.components.attention.AttentionConfig
+ xformers_att_config: Optional[str] = None,
+ xformers_blocksparse_layout: Optional[
+ torch.Tensor
+ ] = None, # This should be part of the config
+ xformers_blocksparse_blocksize: Optional[
+ int
+ ] = 16, # This should be part of the config
+ ):
+ super().__init__()
+
+ self.use_xformers = False
+ if self.use_xformers and not _xformers_available:
+ raise ImportError("\n\n Please install xFormers.")
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout_module = FairseqDropout(
+ dropout, module_name=self.__class__.__name__
+ )
+
+ self.head_dim = embed_dim // num_heads
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim**-0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert (
+ not self.self_attention or self.qkv_same_dim
+ ), "Self-attention requires query, key and value to be of the same size"
+
+ self.k_proj = quant_noise(
+ nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
+ self.v_proj = quant_noise(
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
+ self.q_proj = quant_noise(
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
+
+ self.out_proj = quant_noise(
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+ )
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+ self.beam_size = 1
+ self.reset_parameters()
+
+ self.onnx_trace = False
+ self.skip_embed_dim_check = False
+
+ def prepare_for_onnx_export_(self):
+ self.onnx_trace = True
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ # Empirically observed the convergence to be much better with
+ # the scaled initialization
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+ else:
+ nn.init.xavier_uniform_(self.k_proj.weight)
+ nn.init.xavier_uniform_(self.v_proj.weight)
+ nn.init.xavier_uniform_(self.q_proj.weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.out_proj.bias is not None:
+ nn.init.constant_(self.out_proj.bias, 0.0)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+
+ def _get_reserve_head_index(self, num_heads_to_keep: int):
+ k_proj_heads_norm = []
+ q_proj_heads_norm = []
+ v_proj_heads_norm = []
+
+ for i in range(self.num_heads):
+ start_idx = i * self.head_dim
+ end_idx = (i + 1) * self.head_dim
+ k_proj_heads_norm.append(
+ torch.sum(
+ torch.abs(
+ self.k_proj.weight[
+ start_idx:end_idx,
+ ]
+ )
+ ).tolist()
+ + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
+ )
+ q_proj_heads_norm.append(
+ torch.sum(
+ torch.abs(
+ self.q_proj.weight[
+ start_idx:end_idx,
+ ]
+ )
+ ).tolist()
+ + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
+ )
+ v_proj_heads_norm.append(
+ torch.sum(
+ torch.abs(
+ self.v_proj.weight[
+ start_idx:end_idx,
+ ]
+ )
+ ).tolist()
+ + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
+ )
+
+ heads_norm = []
+ for i in range(self.num_heads):
+ heads_norm.append(
+ k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
+ )
+
+ sorted_head_index = sorted(
+ range(self.num_heads), key=lambda k: heads_norm[k], reverse=True
+ )
+ reserve_head_index = []
+ for i in range(num_heads_to_keep):
+ start = sorted_head_index[i] * self.head_dim
+ end = (sorted_head_index[i] + 1) * self.head_dim
+ reserve_head_index.append((start, end))
+ return reserve_head_index
+
+ def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
+ new_q_weight = []
+ new_q_bias = []
+ new_k_weight = []
+ new_k_bias = []
+ new_v_weight = []
+ new_v_bias = []
+ new_out_proj_weight = []
+
+ for ele in reserve_head_index:
+ start_idx, end_idx = ele
+ new_q_weight.append(
+ self.q_proj.weight[
+ start_idx:end_idx,
+ ]
+ )
+ new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
+
+ new_k_weight.append(
+ self.k_proj.weight[
+ start_idx:end_idx,
+ ]
+ )
+
+ new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
+
+ new_v_weight.append(
+ self.v_proj.weight[
+ start_idx:end_idx,
+ ]
+ )
+ new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
+
+ new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
+
+ new_q_weight = torch.cat(new_q_weight).detach()
+ new_k_weight = torch.cat(new_k_weight).detach()
+ new_v_weight = torch.cat(new_v_weight).detach()
+ new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
+ new_q_weight.requires_grad = True
+ new_k_weight.requires_grad = True
+ new_v_weight.requires_grad = True
+ new_out_proj_weight.requires_grad = True
+
+ new_q_bias = torch.cat(new_q_bias).detach()
+ new_q_bias.requires_grad = True
+
+ new_k_bias = torch.cat(new_k_bias).detach()
+ new_k_bias.requires_grad = True
+
+ new_v_bias = torch.cat(new_v_bias).detach()
+ new_v_bias.requires_grad = True
+
+ self.q_proj.weight = torch.nn.Parameter(new_q_weight)
+ self.q_proj.bias = torch.nn.Parameter(new_q_bias)
+
+ self.k_proj.weight = torch.nn.Parameter(new_k_weight)
+ self.k_proj.bias = torch.nn.Parameter(new_k_bias)
+
+ self.v_proj.weight = torch.nn.Parameter(new_v_weight)
+ self.v_proj.bias = torch.nn.Parameter(new_v_bias)
+
+ self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight)
+
+ self.num_heads = len(reserve_head_index)
+ self.embed_dim = self.head_dim * self.num_heads
+ self.q_proj.out_features = self.embed_dim
+ self.k_proj.out_features = self.embed_dim
+ self.v_proj.out_features = self.embed_dim
+
+ def _set_skip_embed_dim_check(self):
+ self.skip_embed_dim_check = True
+
+ def _pad_masks(
+ self,
+ key_padding_mask: Optional[Tensor],
+ attn_mask: Optional[Tensor],
+ ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
+ if attn_mask is not None:
+ shape = attn_mask.size()[:-1] + torch.Size([1])
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
+ if key_padding_mask is not None:
+ shape = key_padding_mask.size()[:-1] + torch.Size([1])
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ key_padding_mask.new_zeros(shape),
+ ],
+ dim=-1,
+ )
+ return key_padding_mask, attn_mask
+
+ def _add_bias(
+ self,
+ k: Tensor,
+ v: Tensor,
+ key_padding_mask: Optional[Tensor],
+ attn_mask: Optional[Tensor],
+ bsz: int,
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+ assert self.bias_k is not None
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ key_padding_mask, attn_mask = self._pad_masks(
+ key_padding_mask=key_padding_mask, attn_mask=attn_mask
+ )
+ return k, v, key_padding_mask, attn_mask
+
+ def _append_zero_attn(
+ self,
+ k: Tensor,
+ v: Tensor,
+ key_padding_mask: Optional[Tensor],
+ attn_mask: Optional[Tensor],
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+ zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
+ k = torch.cat(
+ [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
+ dim=-2,
+ )
+ v = torch.cat(
+ [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
+ dim=-2,
+ )
+ key_padding_mask, attn_mask = self._pad_masks(
+ key_padding_mask=key_padding_mask, attn_mask=attn_mask
+ )
+ return k, v, key_padding_mask, attn_mask
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Optional[Tensor],
+ value: Optional[Tensor],
+ key_padding_mask: Optional[Tensor] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ need_weights: bool = True,
+ static_kv: bool = False,
+ attn_mask: Optional[Tensor] = None,
+ before_softmax: bool = False,
+ need_head_weights: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+                averaged over heads (default: True).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ is_tpu = query.device.type == "xla"
+
+ tgt_len, bsz, embed_dim = query.size()
+ src_len = tgt_len
+ if not self.skip_embed_dim_check:
+ assert (
+ embed_dim == self.embed_dim
+ ), f"query dim {embed_dim} != {self.embed_dim}"
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ if key is not None:
+ src_len, key_bsz, _ = key.size()
+ if not torch.jit.is_scripting():
+ assert value is not None
+                assert (src_len, key_bsz) == value.shape[:2]
+
+ if (
+ not self.onnx_trace
+ and not is_tpu # don't use PyTorch version on TPUs
+ and incremental_state is None
+ and not static_kv
+ # A workaround for quantization to work. Otherwise JIT compilation
+ # treats bias in linear module as method.
+ and not torch.jit.is_scripting()
+ # The Multihead attention implemented in pytorch forces strong dimension check
+            # for input embedding dimension and K,Q,V projection dimension.
+ # Since pruning will break the dimension check and it is not easy to modify the pytorch API,
+ # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check
+ and not self.skip_embed_dim_check
+ ):
+ assert key is not None and value is not None
+
+ return F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ torch.empty([0]),
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout_module.p,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ self.training or self.dropout_module.apply_during_inference,
+ key_padding_mask.bool() if key_padding_mask is not None else None,
+ need_weights,
+ attn_mask,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ )
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if saved_state is not None and "prev_key" in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ q = self.q_proj(query)
+ k = self.k_proj(query)
+ v = self.v_proj(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.q_proj(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ if self.beam_size > 1 and bsz == key.size(1):
+ # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
+ key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
+ :, :, 0, :
+ ]
+ if key_padding_mask is not None:
+ key_padding_mask = key_padding_mask.view(
+ -1, self.beam_size, key_padding_mask.size(1)
+ )[:, 0, :]
+ k = self.k_proj(key)
+ v = self.v_proj(key)
+
+ else:
+ assert key is not None and value is not None
+ q = self.q_proj(query)
+ k = self.k_proj(key)
+ v = self.v_proj(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k, v, attn_mask, key_padding_mask = self._add_bias(
+ k, v, attn_mask, key_padding_mask, bsz
+ )
+
+ q = (
+ q.contiguous()
+ .view(tgt_len, bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+ kv_bsz = bsz # need default value for scripting
+ if k is not None:
+ kv_bsz = k.size(1)
+ k = (
+ k.contiguous()
+ .view(-1, kv_bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+ if v is not None:
+ v = (
+ v.contiguous()
+ .view(-1, kv_bsz * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if "prev_key" in saved_state:
+ _prev_key = saved_state["prev_key"]
+ assert _prev_key is not None
+ kv_bsz = _prev_key.size(0)
+ prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ assert k is not None
+ k = torch.cat([prev_key, k], dim=1)
+ src_len = k.size(1)
+ if "prev_value" in saved_state:
+ _prev_value = saved_state["prev_value"]
+ assert _prev_value is not None
+ assert kv_bsz == _prev_value.size(0)
+ prev_value = _prev_value.view(
+ kv_bsz * self.num_heads, -1, self.head_dim
+ )
+ if static_kv:
+ v = prev_value
+ else:
+ assert v is not None
+ v = torch.cat([prev_value, v], dim=1)
+ prev_key_padding_mask: Optional[Tensor] = None
+ if "prev_key_padding_mask" in saved_state:
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+ assert k is not None and v is not None
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+ key_padding_mask=key_padding_mask,
+ prev_key_padding_mask=prev_key_padding_mask,
+ batch_size=kv_bsz,
+ src_len=k.size(1),
+ static_kv=static_kv,
+ )
+
+ saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_value"] = v.view(
+ kv_bsz, self.num_heads, -1, self.head_dim
+ )
+ saved_state["prev_key_padding_mask"] = key_padding_mask
+ # In this branch incremental_state is never None
+ assert incremental_state is not None
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
+ assert k is not None
+ assert k.size(1) == src_len
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == kv_bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ assert v is not None
+ src_len += 1
+ k, v, key_padding_mask, attn_mask = self._append_zero_attn(
+ k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask
+ )
+
+ if self.encoder_decoder_attention and bsz != kv_bsz:
+ attn_weights = torch.einsum(
+ "bxhtd,bhsd->bxhts",
+ q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
+ k.view((kv_bsz, self.num_heads) + k.size()[1:]),
+ )
+ attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
+ else:
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.size()) == [
+ bsz * self.num_heads,
+ tgt_len,
+ src_len,
+ ]
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ if self.onnx_trace:
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
+ attn_weights += attn_mask
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ if not is_tpu:
+ attn_weights = attn_weights.view(
+ kv_bsz, -1, self.num_heads, tgt_len, src_len
+ )
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .to(torch.bool),
+ float("-inf"),
+ )
+ else:
+ attn_weights = attn_weights.transpose(0, 2)
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+ attn_weights = attn_weights.transpose(0, 2)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = utils.softmax(
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
+ )
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = self.dropout_module(attn_weights)
+
+ assert v is not None
+ attn: Optional[Tensor] = None
+ if self.encoder_decoder_attention and bsz != kv_bsz:
+ attn = torch.einsum(
+ "bxhts,bhsd->bxhtd",
+ attn_probs.view(
+ (
+ kv_bsz,
+ -1,
+ self.num_heads,
+ )
+ + attn_probs.size()[1:]
+ ),
+ v.view(
+ (
+ kv_bsz,
+ self.num_heads,
+ )
+ + v.size()[1:]
+ ),
+ )
+ attn = attn.reshape((-1,) + attn.size()[-2:])
+ else:
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [
+ bsz * self.num_heads,
+ tgt_len,
+ self.head_dim,
+ ]
+ if self.onnx_trace and attn.size(1) == 1:
+ # when ONNX tracing a single decoder step (sequence length == 1)
+ # the transpose is a no-op copy before view, thus unnecessary
+ attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
+ else:
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
+ attn = self.out_proj(attn)
+ attn_weights: Optional[Tensor] = None
+ if need_weights:
+ attn_weights = attn_weights_float.view(
+ bsz, self.num_heads, tgt_len, src_len
+ ).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+
+ return attn, attn_weights
+
+ @staticmethod
+ def _append_prev_key_padding_mask(
+ key_padding_mask: Optional[Tensor],
+ prev_key_padding_mask: Optional[Tensor],
+ batch_size: int,
+ src_len: int,
+ static_kv: bool,
+ ) -> Optional[Tensor]:
+ # saved key padding masks have shape (bsz, seq_len)
+ if prev_key_padding_mask is not None and static_kv:
+ new_key_padding_mask = prev_key_padding_mask
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
+ new_key_padding_mask = torch.cat(
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
+ )
+ # During incremental decoding, as the padding token enters and
+ # leaves the frame, there will be a time when prev or current
+ # is None
+ elif prev_key_padding_mask is not None:
+ if src_len > prev_key_padding_mask.size(1):
+ filler = torch.zeros(
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
+ device=prev_key_padding_mask.device,
+ )
+ new_key_padding_mask = torch.cat(
+ [prev_key_padding_mask.float(), filler.float()], dim=1
+ )
+ else:
+ new_key_padding_mask = prev_key_padding_mask.float()
+ elif key_padding_mask is not None:
+ if src_len > key_padding_mask.size(1):
+ filler = torch.zeros(
+ (batch_size, src_len - key_padding_mask.size(1)),
+ device=key_padding_mask.device,
+ )
+ new_key_padding_mask = torch.cat(
+ [filler.float(), key_padding_mask.float()], dim=1
+ )
+ else:
+ new_key_padding_mask = key_padding_mask.float()
+ else:
+ new_key_padding_mask = prev_key_padding_mask
+ return new_key_padding_mask
+
+ @torch.jit.export
+ def reorder_incremental_state(
+ self,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+ new_order: Tensor,
+ ):
+ """Reorder buffered internal state (for incremental generation)."""
+ input_buffer = self._get_input_buffer(incremental_state)
+ if input_buffer is not None:
+ for k in input_buffer.keys():
+ input_buffer_k = input_buffer[k]
+ if input_buffer_k is not None:
+ if self.encoder_decoder_attention:
+ if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
+ return incremental_state
+ elif self.beam_size > 1:
+ input_buffer[k] = input_buffer_k.index_select(
+ 0,
+ new_order.reshape(-1, self.beam_size)[:, 0]
+ // self.beam_size,
+ )
+ else:
+ input_buffer[k] = input_buffer_k.index_select(0, new_order)
+ else:
+ input_buffer[k] = input_buffer_k.index_select(0, new_order)
+ incremental_state = self._set_input_buffer(incremental_state, input_buffer)
+ return incremental_state
+
+ def set_beam_size(self, beam_size):
+ """Used for effiecient beamable enc-dec attention"""
+ self.beam_size = beam_size
+
+ def _get_input_buffer(
+ self,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+ ) -> Dict[str, Optional[Tensor]]:
+ result = self.get_incremental_state(incremental_state, "attn_state")
+ if result is not None:
+ return result
+ else:
+ empty_result: Dict[str, Optional[Tensor]] = {}
+ return empty_result
+
+ def _set_input_buffer(
+ self,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+ buffer: Dict[str, Optional[Tensor]],
+ ):
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
+
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+ return attn_weights
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ prefix = name + "." if name != "" else ""
+ items_to_add = {}
+ keys_to_remove = []
+ for k in state_dict.keys():
+ if k.endswith(prefix + "in_proj_weight"):
+ # in_proj_weight used to be q + k + v with same dimensions
+ dim = int(state_dict[k].shape[0] / 3)
+ items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
+
+ keys_to_remove.append(k)
+
+ k_bias = prefix + "in_proj_bias"
+ if k_bias in state_dict.keys():
+ dim = int(state_dict[k].shape[0] / 3)
+ items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
+ items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
+ dim : 2 * dim
+ ]
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
+
+ keys_to_remove.append(prefix + "in_proj_bias")
+
+ for k in keys_to_remove:
+ del state_dict[k]
+
+ for key, value in items_to_add.items():
+ state_dict[key] = value
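+
+
+# Hedged usage sketch (illustrative only): self-attention over a
+# (time, batch, channel) input, followed by the head-pruning helpers defined
+# above.  The toy sizes and the number of heads kept are placeholders.
+def _demo_multihead_attention() -> None:
+    tgt_len, bsz, embed_dim, num_heads = 10, 2, 64, 4
+    mha = MultiheadAttention(embed_dim, num_heads, self_attention=True)
+    x = torch.randn(tgt_len, bsz, embed_dim)
+    attn, attn_weights = mha(query=x, key=x, value=x, need_weights=True)
+    assert attn.shape == (tgt_len, bsz, embed_dim)
+    # With need_weights=True the returned weights are averaged over heads.
+    assert attn_weights.shape == (bsz, tgt_len, tgt_len)
+
+    # Keep the two heads whose q/k/v projections have the largest L1 norm,
+    # then shrink the projections in place.
+    reserve = mha._get_reserve_head_index(num_heads_to_keep=2)
+    mha._adaptive_prune_heads(reserve)
+    mha._set_skip_embed_dim_check()
+    assert mha.num_heads == 2 and mha.embed_dim == 2 * mha.head_dim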
diff --git a/egs/librispeech/SSL/hubert/beam_search.py b/egs/librispeech/SSL/hubert/beam_search.py
new file mode 120000
index 000000000..f4d4b5732
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/beam_search.py
@@ -0,0 +1 @@
+../../ASR/zipformer/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/hubert/dataset.py b/egs/librispeech/SSL/hubert/dataset.py
new file mode 100644
index 000000000..76edfb340
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/dataset.py
@@ -0,0 +1,367 @@
+# Copyright 2024 Xiaomi Corporation (authors: Yifan Yang)
+# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from typing import Any, Dict, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from lhotse import validate
+from lhotse.cut import CutSet
+from lhotse.dataset.collation import read_audio_from_cuts
+from torch.utils.data.dataloader import default_collate
+
+
+class HubertDataset(torch.utils.data.Dataset):
+ """
+ In this implementation, there will always be a single channel.
+
+ Returns:
+
+ .. code-block::
+
+        {
+            'cuts': CutSet,
+            'audio': (B x NumSamples) float tensor,
+            'padding_mask': (B x NumSamples) bool tensor,
+            'kmeans': (B x NumFrames) int tensor of k-means target labels
+        }
+ """
+
+ def __init__(
+ self,
+ max_sample_size: Optional[int] = None,
+ sample_rate: float = 16000,
+ label_rate: float = 50,
+ random_crop: bool = True,
+ pad_audio: bool = False,
+ num_classes: list = [504],
+ do_normalize: bool = True,
+ ) -> None:
+ super().__init__()
+ self.sample_rate = sample_rate
+ self.label_rate = label_rate
+ self.random_crop = random_crop
+ self.pad_audio = pad_audio
+ self.num_classes = num_classes
+ self.normalize = do_normalize
+ self.max_sample_size = (
+ max_sample_size if max_sample_size is not None else sys.maxsize
+ )
+
+ def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
+ self._validate(cuts)
+ audio, _ = read_audio_from_cuts(cuts)
+ for i, item in enumerate(audio):
+ audio[i] = self.postprocess(item, self.sample_rate)
+ audio_lens = [cut.num_samples for cut in cuts]
+
+ if self.pad_audio:
+ audio_size = min(max(audio_lens), self.max_sample_size)
+ else:
+ audio_size = min(min(audio_lens), self.max_sample_size)
+
+ audio, padding_mask, audio_starts = self.collater_audio(
+ audio, audio_lens, audio_size
+ )
+
+ kmeans = [cut.custom["kmeans"] for cut in cuts]
+ kmeans = [
+ torch.tensor([int(item) for item in label.split()], dtype=torch.int64)
+ for label in kmeans
+ ]
+ kmeans, _ = self.collater_frm_label(kmeans, audio_size, audio_starts)
+
+ return {
+ "cuts": cuts,
+ "audio": audio,
+ "padding_mask": padding_mask,
+ "kmeans": kmeans,
+ }
+
+ def postprocess(self, wav, cur_sample_rate):
+ if wav.dim() == 2:
+ wav = wav.mean(-1)
+ assert wav.dim() == 1, wav.dim()
+
+ if cur_sample_rate != self.sample_rate:
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
+
+ if self.normalize:
+ with torch.no_grad():
+ wav = F.layer_norm(wav, wav.shape)
+ return wav
+
+ def _validate(self, cuts: CutSet) -> None:
+ validate(cuts)
+ assert all(cut.has_recording for cut in cuts)
+
+ def crop_to_max_size(self, wav, target_size):
+ size = len(wav)
+ diff = size - target_size
+ if diff <= 0:
+ return wav, 0
+
+ start, end = 0, target_size
+ if self.random_crop:
+ start = np.random.randint(0, diff + 1)
+ end = size - diff + start
+ return wav[start:end], start
+
+ def collater_audio(self, audios, audio_lens, audio_size):
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
+ padding_mask = (
+ torch.BoolTensor(collated_audios.shape).fill_(False)
+ # if self.pad_audio else None
+ )
+ audio_starts = [0 for _ in audios]
+ for i, (audio, audio_len) in enumerate(zip(audios, audio_lens)):
+ audio = audio[:audio_len]
+ diff = audio_len - audio_size
+ if diff == 0:
+ collated_audios[i] = audio
+ elif diff < 0:
+ assert self.pad_audio
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
+ padding_mask[i, diff:] = True
+ else:
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
+ audio, audio_size
+ )
+ return collated_audios, padding_mask, audio_starts
+
+ def collate_tokens(
+ self,
+ values,
+ pad_idx,
+ eos_idx=None,
+ left_pad=False,
+ move_eos_to_beginning=False,
+ pad_to_length=None,
+ pad_to_multiple=1,
+ pad_to_bsz=None,
+ ):
+ """Convert a list of 1d tensors into a padded 2d tensor."""
+ size = max(v.size(0) for v in values)
+ size = size if pad_to_length is None else max(size, pad_to_length)
+ if pad_to_multiple != 1 and size % pad_to_multiple != 0:
+ size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
+
+ batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
+ res = values[0].new(batch_size, size).fill_(pad_idx)
+
+ def copy_tensor(src, dst):
+ assert dst.numel() == src.numel()
+ if move_eos_to_beginning:
+ if eos_idx is None:
+ # if no eos_idx is specified, then use the last token in src
+ dst[0] = src[-1]
+ else:
+ dst[0] = eos_idx
+ dst[1:] = src[:-1]
+ else:
+ dst.copy_(src)
+
+ for i, v in enumerate(values):
+ copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
+ return res
+
+ def collater_frm_label(self, targets, audio_size, audio_starts):
+ label_rate = self.label_rate
+ pad = self.num_classes[0] - 1
+ assert label_rate > 0
+ s2f = label_rate / self.sample_rate
+ frm_starts = [int(round(s * s2f)) for s in audio_starts]
+ frm_size = int(round(audio_size * s2f))
+ if not self.pad_audio:
+ rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
+ frm_size = min(frm_size, *rem_size)
+ targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
+
+ lengths = torch.LongTensor([len(t) for t in targets])
+ targets = self.collate_tokens(targets, pad_idx=pad, left_pad=False)
+ return targets, lengths
+
+
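+# Hedged sketch (illustrative only) of the label collation above: k-means
+# targets run at label_rate (50 Hz) against sample_rate (16 kHz) audio, so a
+# crop starting at sample 32000 maps to label frame 32000 * 50 / 16000 = 100.
+# collate_tokens right-pads shorter label sequences with pad index
+# num_classes[0] - 1.
+def _demo_collate_frame_labels() -> None:
+    ds = HubertDataset(num_classes=[504])
+    labels = [torch.arange(5), torch.arange(3)]
+    padded = ds.collate_tokens(labels, pad_idx=ds.num_classes[0] - 1)
+    assert padded.shape == (2, 5)
+    assert int(padded[1, -1]) == 503
+
+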
+class HubertAsrDataset(torch.utils.data.Dataset):
+ """
+ In this implementation, there will always be a single channel.
+
+ Returns:
+
+ .. code-block::
+
+        {
+            'cuts': CutSet,
+            'audio': (B x NumSamples) float tensor,
+            'padding_mask': (B x NumSamples) bool tensor,
+            'supervisions': collated supervision texts for each cut
+        }
+ """
+
+ def __init__(
+ self,
+ max_sample_size: Optional[int] = None,
+ sample_rate: float = 16000,
+ random_crop: bool = True,
+ pad_audio: bool = True,
+ do_normalize: bool = True,
+ ) -> None:
+ super().__init__()
+ self.sample_rate = sample_rate
+ self.random_crop = random_crop
+ self.pad_audio = pad_audio
+ self.normalize = do_normalize
+ self.max_sample_size = (
+ max_sample_size if max_sample_size is not None else sys.maxsize
+ )
+
+ def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
+ self._validate(cuts)
+ audio, _ = read_audio_from_cuts(cuts)
+ for i, item in enumerate(audio):
+ audio[i] = self.postprocess(item, self.sample_rate)
+ audio_lens = [cut.num_samples for cut in cuts]
+ if self.pad_audio:
+ audio_size = min(max(audio_lens), self.max_sample_size)
+ else:
+ audio_size = min(min(audio_lens), self.max_sample_size)
+
+ audio, padding_mask, audio_starts = self.collater_audio(
+ audio, audio_lens, audio_size
+ )
+
+ return {
+ "cuts": cuts,
+ "audio": audio,
+ "padding_mask": padding_mask,
+ "supervisions": default_collate(
+ [
+ {
+ "text": supervision.text,
+ }
+ for sequence_idx, cut in enumerate(cuts)
+ for supervision in cut.supervisions
+ ]
+ ),
+ }
+
+ def postprocess(self, wav, cur_sample_rate):
+ if wav.dim() == 2:
+ wav = wav.mean(-1)
+ assert wav.dim() == 1, wav.dim()
+
+ if cur_sample_rate != self.sample_rate:
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
+
+ if self.normalize:
+ with torch.no_grad():
+ wav = F.layer_norm(wav, wav.shape)
+ return wav
+
+ def _validate(self, cuts: CutSet) -> None:
+ validate(cuts)
+ assert all(cut.has_recording for cut in cuts)
+
+ def crop_to_max_size(self, wav, target_size):
+ size = len(wav)
+ diff = size - target_size
+ if diff <= 0:
+ return wav, 0
+
+ start, end = 0, target_size
+ if self.random_crop:
+ start = np.random.randint(0, diff + 1)
+ end = size - diff + start
+ return wav[start:end], start
+
+ def collater_audio(self, audios, audio_lens, audio_size):
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
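+        # padding_mask marks padded (invalid) sample positions with True.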
+ padding_mask = (
+ torch.BoolTensor(collated_audios.shape).fill_(False)
+ # if self.pad_audio else None
+ )
+ audio_starts = [0 for _ in audios]
+ for i, (audio, audio_len) in enumerate(zip(audios, audio_lens)):
+ audio = audio[:audio_len]
+ diff = audio_len - audio_size
+ if diff == 0:
+ collated_audios[i] = audio
+ elif diff < 0:
+ assert self.pad_audio
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
+ padding_mask[i, diff:] = True
+ else:
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
+ audio, audio_size
+ )
+ return collated_audios, padding_mask, audio_starts
+
+ def collate_tokens(
+ self,
+ values,
+ pad_idx,
+ eos_idx=None,
+ left_pad=False,
+ move_eos_to_beginning=False,
+ pad_to_length=None,
+ pad_to_multiple=1,
+ pad_to_bsz=None,
+ ):
+ """Convert a list of 1d tensors into a padded 2d tensor."""
+ size = max(v.size(0) for v in values)
+ size = size if pad_to_length is None else max(size, pad_to_length)
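+        # If requested, round the padded length up to the next multiple
+        # of pad_to_multiple (e.g. 13 -> 16 when pad_to_multiple is 8).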
+ if pad_to_multiple != 1 and size % pad_to_multiple != 0:
+ size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
+
+ batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
+ res = values[0].new(batch_size, size).fill_(pad_idx)
+
+ def copy_tensor(src, dst):
+ assert dst.numel() == src.numel()
+ if move_eos_to_beginning:
+ if eos_idx is None:
+ # if no eos_idx is specified, then use the last token in src
+ dst[0] = src[-1]
+ else:
+ dst[0] = eos_idx
+ dst[1:] = src[:-1]
+ else:
+ dst.copy_(src)
+
+ for i, v in enumerate(values):
+ copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
+ return res
+
+
+if __name__ == "__main__":
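+    # Smoke test: draw a single batch from train-clean-100 with a
+    # DynamicBucketingSampler and print it.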
+ from lhotse import load_manifest_lazy
+ from lhotse.dataset import DynamicBucketingSampler
+ from torch.utils.data import DataLoader
+
+ dataset = HubertDataset()
+ cuts = load_manifest_lazy("data/fbank2/librispeech_cuts_train-clean-100.jsonl.gz")
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=100,
+ shuffle=False,
+ )
+ dl = DataLoader(
+ dataset,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=2,
+ )
+
+ for batch_idx, batch in enumerate(dl):
+ print(batch)
+ break
diff --git a/egs/librispeech/SSL/hubert/decode.py b/egs/librispeech/SSL/hubert/decode.py
new file mode 100644
index 000000000..837061b8c
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/decode.py
@@ -0,0 +1,1045 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao,
+# Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./hubert/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./hubert/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./hubert/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./hubert/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./hubert/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./hubert/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./hubert/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from finetune import add_model_arguments, get_model, get_params
+from hubert import add_hubert_arguments
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="hubert/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ add_hubert_arguments(parser)
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+    - value: It contains the decoding result. `len(value)` equals the
+             batch size. `value[i]` is the decoding result for the i-th
+             utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+
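+    # encoder_out: (N, T, C); encoder_out_lens[i] is the number of valid
+    # output frames for the i-th utterance.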
+ encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(batch["supervisions"]["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
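+    # Build a result key that encodes the decoding settings, so that
+    # hypotheses from different configurations are reported separately.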
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["cuts"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
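+            # Scale the LM scores carried by LG so their weight can be
+            # tuned via --ngram-lm-scale.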
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ dev_clean_cuts = librispeech.dev_clean_cuts()
+ dev_other_cuts = librispeech.dev_other_cuts()
+
+ dev_clean_dl = librispeech.test_dataloaders(
+ dev_clean_cuts,
+ do_normalize=params.do_normalize,
+ )
+ dev_other_dl = librispeech.test_dataloaders(
+ dev_other_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ test_clean_cuts = librispeech.test_clean_cuts()
+ test_other_cuts = librispeech.test_other_cuts()
+
+ test_clean_dl = librispeech.test_dataloaders(
+ test_clean_cuts,
+ do_normalize=params.do_normalize,
+ )
+ test_other_dl = librispeech.test_dataloaders(
+ test_other_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
+ test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/SSL/hubert/decode_ce.py b/egs/librispeech/SSL/hubert/decode_ce.py
new file mode 100644
index 000000000..a8d8bc9c2
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/decode_ce.py
@@ -0,0 +1,1045 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao,
+# Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./hubert/decode_ce.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./hubert/decode_ce.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./hubert/decode_ce.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./hubert/decode_ce.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./hubert/decode_ce.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./hubert/decode_ce.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./hubert/decode_ce.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./hubert/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from finetune_ce import add_model_arguments, get_model, get_params
+from hubert_ce import add_hubert_arguments
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="hubert/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ add_hubert_arguments(parser)
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+    - value: It contains the decoding result. `len(value)` equals the
+             batch size. `value[i]` is the decoding result for the i-th
+             utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+
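+    # encoder_out: (N, T, C); encoder_out_lens[i] is the number of valid
+    # output frames for the i-th utterance.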
+ encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(batch["supervisions"]["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
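+    # Build a result key that encodes the decoding settings, so that
+    # hypotheses from different configurations are reported separately.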
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["cuts"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
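+            # Each line of --context-file is treated as one biasing phrase; it is
+            # BPE-encoded and paired with a per-phrase score of 0.0 before being
+            # added to the context graph built with --context-score.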
+            with open(params.context_file) as f:
+                for line in f:
+                    contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ dev_clean_cuts = librispeech.dev_clean_cuts()
+ dev_other_cuts = librispeech.dev_other_cuts()
+
+ dev_clean_dl = librispeech.test_dataloaders(
+ dev_clean_cuts,
+ do_normalize=params.do_normalize,
+ )
+ dev_other_dl = librispeech.test_dataloaders(
+ dev_other_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ test_clean_cuts = librispeech.test_clean_cuts()
+ test_other_cuts = librispeech.test_other_cuts()
+
+ test_clean_dl = librispeech.test_dataloaders(
+ test_clean_cuts,
+ do_normalize=params.do_normalize,
+ )
+ test_other_dl = librispeech.test_dataloaders(
+ test_other_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
+    test_dls = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/SSL/hubert/decoder.py b/egs/librispeech/SSL/hubert/decoder.py
new file mode 120000
index 000000000..a2138e5da
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/decoder.py
@@ -0,0 +1 @@
+../../ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py
new file mode 100644
index 000000000..201847aed
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/finetune.py
@@ -0,0 +1,1254 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For HuBERT model finetuning:
+./hubert/finetune.py \
+ --world-size 8 \
+ --num-epochs 200 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir hubert/exp \
+ --full-libri 0 \
+ --max-duration 200
+
+It supports finetuning with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from hubert import HubertModel, add_hubert_arguments
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
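+    # For example, with --max-duration 200 and --world-size 8 (the values in the
+    # usage example above) and the defaults --accum-grad 1 and --ref-duration 600,
+    # every optimizer step counts as about 200 * 1 * 8 / 600 ≈ 2.7 reference batches.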
+ return (
+ params.batch_idx_train
+ * params.accum_grad
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=222,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="hubert/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--pretrained-dir",
+ type=str,
+ help="""The pretrained model dir.
+ It specifies the directory where the pretrained checkpoint is saved.""",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.001, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=100000,
+ help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not changing this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=100,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--sanity-check",
+ type=str2bool,
+ default=False,
+ help="Check if any of the batches in epoch 1 would cause OOM.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=100000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--accum-grad",
+ type=int,
+ default=1,
+ help="""update gradient when batch_idx_train % accum_grad == 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_hubert_arguments(parser)
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+        contains the number of updates applied to the model so far across
+        epochs.
+
+    - sub_batch_idx_train: It contains the number of batches trained so far
+        across epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "sub_batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for pruned RNN-T loss
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    if getattr(params, "pretrained_dir", None) is not None:
+ logging.info(f"Loading {params.pretrained_dir}")
+ pretrained = torch.load(params.pretrained_dir)
+ encoder = HubertModel(params)
+ encoder.load_state_dict(pretrained["model"])
+ else:
+ encoder = HubertModel(params)
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=params.encoder_embed_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=params.encoder_embed_dim,
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mix precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `dataset.HubertAsrDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss, num_frames = model(
+ x=audio,
+ padding_mask=padding_mask,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
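+        # At the start of training the simple loss dominates (scale 1.0 vs. 0.1
+        # for the pruned loss); once batch_idx_train reaches warm_step the scales
+        # settle at s (--simple-loss-scale) and 1.0 respectively.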
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = num_frames.sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for sub_batch_idx, batch in enumerate(train_dl):
+ params.sub_batch_idx_train += 1
+ batch_idx = sub_batch_idx // params.accum_grad
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss / params.accum_grad).backward()
+
+ if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
+ params.batch_idx_train += 1
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ else:
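+                    # Not the last sub-batch of the accumulation group: keep
+                    # accumulating gradients and skip the optimizer step.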
+ continue
+
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ if batch_idx % params.accum_grad != params.accum_grad - 1:
+ optimizer.zero_grad()
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
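+        # The running average is kept in float64, presumably so that the repeated
+        # weighted updates in update_averaged_model accumulate less rounding error.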
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ train_cuts = (
+ librispeech.train_all_shuf_cuts()
+ if params.full_libri
+ else librispeech.train_clean_100_cuts()
+ )
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts,
+ do_normalize=params.do_normalize,
+ sampler_state_dict=sampler_state_dict,
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+
+ valid_dl = librispeech.valid_dataloaders(
+ valid_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ if params.sanity_check and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `dataset.HubertAsrDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ audio = batch["audio"]
+ logging.info(f"audio shape: {audio.shape}")
+
+ y = sp.encode(batch["supervisions"]["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py
new file mode 100644
index 000000000..e69a5a8cd
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/finetune_ce.py
@@ -0,0 +1,1254 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For HuBERT model finetuning:
+./hubert/finetune_ce.py \
+ --world-size 8 \
+ --num-epochs 200 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir hubert/exp \
+ --full-libri 0 \
+ --max-duration 200
+
+It supports finetuning with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from hubert_ce import HubertModel, add_hubert_arguments
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
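+    # For example, with --max-duration 200 and --world-size 8 (the values in the
+    # usage example above) and the defaults --accum-grad 1 and --ref-duration 600,
+    # every optimizer step counts as about 200 * 1 * 8 / 600 ≈ 2.7 reference batches.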
+ return (
+ params.batch_idx_train
+ * params.accum_grad
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=222,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="hubert/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--pretrained-dir",
+ type=str,
+ help="""The pretrained model dir.
+ It specifies the directory where the pretrained checkpoint is saved.""",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.001, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=100000,
+ help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not changing this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=100,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--sanity-check",
+ type=str2bool,
+ default=False,
+ help="Check if any of the batches in epoch 1 would cause OOM.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=100000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--accum-grad",
+ type=int,
+ default=1,
+ help="""update gradient when batch_idx_train % accum_grad == 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_hubert_arguments(parser)
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+        contains the number of updates applied to the model so far across
+        epochs.
+
+    - sub_batch_idx_train: It contains the number of batches trained so far
+        across epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "sub_batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for pruned RNN-T loss
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    if getattr(params, "pretrained_dir", None) is not None:
+ logging.info(f"Loading {params.pretrained_dir}")
+ pretrained = torch.load(params.pretrained_dir)
+ encoder = HubertModel(params)
+ encoder.load_state_dict(pretrained["model"])
+ else:
+ encoder = HubertModel(params)
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=params.encoder_embed_dim,
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=params.encoder_embed_dim,
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mix precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `dataset.HubertAsrDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss, num_frames = model(
+ x=audio,
+ padding_mask=padding_mask,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
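+        # At the start of training the simple loss dominates (scale 1.0 vs. 0.1
+        # for the pruned loss); once batch_idx_train reaches warm_step the scales
+        # settle at s (--simple-loss-scale) and 1.0 respectively.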
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = num_frames.sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+        The learning rate scheduler; we call step_batch() after every optimizer step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+      world_size:
+        Number of GPUs used for DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the current process in DDP training. If DDP is not used,
+        it should be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for sub_batch_idx, batch in enumerate(train_dl):
+ params.sub_batch_idx_train += 1
+ batch_idx = sub_batch_idx // params.accum_grad
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss / params.accum_grad).backward()
+
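+            # Gradient accumulation: every sub-batch contributes loss/accum_grad
+            # to .grad, and the optimizer is stepped (and the grads cleared) only
+            # on the last of each group of `accum_grad` consecutive sub-batches.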
+ if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
+ params.batch_idx_train += 1
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ else:
+ continue
+
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+            # If the grad scale is small (below 8.0, or below 32.0 every 400
+            # batches), try doubling it. The _growth_interval of the grad scaler
+            # is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ if batch_idx % params.accum_grad != params.accum_grad - 1:
+ optimizer.zero_grad()
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ train_cuts = (
+ librispeech.train_all_shuf_cuts()
+ if params.full_libri
+ else librispeech.train_clean_100_cuts()
+ )
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts,
+ do_normalize=params.do_normalize,
+ sampler_state_dict=sampler_state_dict,
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+
+ valid_dl = librispeech.valid_dataloaders(
+ valid_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ if params.sanity_check and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `dataset.HubertAsrDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ audio = batch["audio"]
+ logging.info(f"audio shape: {audio.shape}")
+
+ y = sp.encode(batch["supervisions"]["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/SSL/hubert/hubert.py b/egs/librispeech/SSL/hubert/hubert.py
new file mode 100644
index 000000000..f800044f4
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/hubert.py
@@ -0,0 +1,984 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import argparse
+import logging
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils import GradMultiply, LayerNorm
+from wav2vec2_module import ConvFeatureExtractionModel, TransformerEncoder
+
+
+def compute_mask_indices(
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+ require_same_masks: bool = True,
+ mask_dropout: float = 0.0,
+ add_masks: bool = False,
+ seed: Optional[int] = None,
+ epoch: Optional[int] = None,
+ indices: Optional[torch.Tensor] = None,
+ idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
+ num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape
+
+ Args:
+        shape: the shape for which to compute masks.
+ should be of size 2 where first element is batch size and 2nd is timesteps
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+ mask_type: how to compute mask lengths
+ static = fixed size
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+            poisson = sample from Poisson distribution with lambda = mask length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, will use a recursive algorithm that prevents masked spans from overlapping
+        min_space: only used if no_overlap is True; this is how many elements to keep unmasked between spans
+        require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
+        mask_dropout: randomly drop out this percentage of masks in each example
+ """
+
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ if num_mask_ver == 1:
+ all_num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * all_sz / float(mask_length)
+ + np.random.rand()
+ )
+ all_num_mask = max(min_masks, all_num_mask)
+
+ mask_idcs = []
+ for i in range(bsz):
+ if seed is not None and epoch is not None and indices is not None:
+ seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
+ else:
+ seed_i = None
+
+ rng = np.random.default_rng(seed_i)
+
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ assert sz >= 0, sz
+ else:
+ sz = all_sz
+
+ if num_mask_ver == 1:
+ if padding_mask is not None:
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + np.random.rand()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ num_mask = all_num_mask
+ elif num_mask_ver == 2:
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + rng.random()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ raise ValueError()
+
+ if mask_type == "static":
+ lengths = np.full(num_mask, mask_length)
+ elif mask_type == "uniform":
+            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
+ elif mask_type == "normal":
+ lengths = rng.normal(mask_length, mask_other, size=num_mask)
+ lengths = [max(1, int(round(x))) for x in lengths]
+ elif mask_type == "poisson":
+ lengths = rng.poisson(mask_length, size=num_mask)
+ lengths = [int(round(x)) for x in lengths]
+ else:
+ raise Exception("unknown mask selection " + mask_type)
+
+ if sum(lengths) == 0:
+ if mask_type == "static":
+                raise ValueError("this should never happen")
+ else:
+ lengths = [min(mask_length, sz - 1)]
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+                span_start = rng.integers(s, e - length)
+ mask_idc.extend(span_start + i for i in range(length))
+
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = np.fromiter(
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    int,
+ )
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ probs = lens / np.sum(lens)
+ c = rng.choice(len(parts), p=probs)
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ if idc_select_ver == 1:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+ elif idc_select_ver == 2:
+ mask_idc = rng.choice(sz, num_mask, replace=False)
+ else:
+ raise ValueError()
+
+ mask_idc = np.asarray(
+ [
+ mask_idc[j] + offset
+ for j in range(len(mask_idc))
+ for offset in range(lengths[j])
+ ]
+ )
+
+ mask_idc = np.unique(mask_idc[mask_idc < sz])
+ if len(mask_idc) >= sz:
+ raise ValueError(
+ (
+ f"the entire sequence is masked. "
+                    f"sz={sz}; mask_idc={mask_idc}; "
+ f"index={indices[i] if indices is not None else None}"
+ )
+ )
+ mask_idcs.append(mask_idc)
+
+ target_len = None
+ if require_same_masks:
+ if add_masks:
+ target_len = max([len(m) for m in mask_idcs])
+ else:
+ target_len = min([len(m) for m in mask_idcs])
+
+ for i, mask_idc in enumerate(mask_idcs):
+ if target_len is not None and len(mask_idc) > target_len:
+ mask_idc = rng.choice(mask_idc, target_len, replace=False)
+
+ mask[i, mask_idc] = True
+
+ if target_len is not None and len(mask_idc) < target_len:
+ unmasked = np.flatnonzero(~mask[i])
+ to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
+ mask[i, to_mask] = True
+
+ if mask_dropout > 0:
+ masked = np.flatnonzero(mask[i])
+ num_holes = np.rint(len(masked) * mask_dropout).astype(int)
+ to_drop = rng.choice(masked, num_holes, replace=False)
+ mask[i, to_drop] = False
+
+ return mask
+
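+
+# Illustrative only (not called anywhere in this file): a minimal sketch of the
+# shape contract of compute_mask_indices with the default static span masking.
+def _example_compute_mask_indices():
+    # A batch of 2 utterances, 100 feature frames each, no padding.
+    mask = compute_mask_indices(
+        shape=(2, 100),
+        padding_mask=None,
+        mask_prob=0.65,
+        mask_length=10,
+    )
+    # `mask` is a (2, 100) boolean np.ndarray; True marks the frames whose
+    # features are later replaced by the learned mask embedding in apply_mask().
+    return mask
+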
+
+def add_hubert_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--label-rate",
+ type=float,
+ default=50,
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=float,
+ default=16000,
+ )
+
+ parser.add_argument(
+ "--extractor-mode",
+ type=str,
+ default="default",
+        help="""mode for feature extractor, should be in EXTRACTOR_MODE_CHOICES. default has a single group
+ norm with d groups in the first conv block, whereas layer_norm
+ has layer norms in every block (meant to use with normalize=True)""",
+ )
+ parser.add_argument(
+ "--encoder-layers",
+ type=int,
+ default=12,
+ help="num encoder layers in the transformer",
+ )
+
+ parser.add_argument(
+ "--encoder-embed-dim",
+ type=int,
+ default=768,
+ help="encoder embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-ffn-embed-dim",
+ type=int,
+ default=3072,
+ help="encoder embedding dimension for FFN",
+ )
+
+ parser.add_argument(
+ "--encoder-attention-heads",
+ type=int,
+ default=12,
+ help="num encoder attention heads",
+ )
+
+ parser.add_argument(
+ "--activation-fn",
+ type=str,
+ choices=[
+ "relu",
+ "gelu",
+ "gelu_fast",
+ "gelu_accurate",
+ "tanh",
+ "linear",
+ ],
+ default="gelu",
+ help="activation function to use",
+ )
+
+ parser.add_argument(
+ "--layer-type",
+ type=str,
+ choices=["transformer", "conformer", "trf_adp"],
+ default="transformer",
+ help="layer type in encoder",
+ )
+
+ # dropouts
+ parser.add_argument(
+ "--dropout",
+ type=float,
+ default=0.1,
+ help="dropout probability for the transformer",
+ )
+
+ parser.add_argument(
+ "--attention-dropout",
+ type=float,
+ default=0.1,
+ help="dropout probability for attention weights",
+ )
+
+ parser.add_argument(
+ "--activation-dropout",
+ type=float,
+ default=0.0,
+ help="dropout probability after activation in FFN",
+ )
+
+ parser.add_argument(
+ "--encoder-layerdrop",
+ type=float,
+ default=0.0,
+        help="probability of dropping a transformer layer",
+ )
+
+ parser.add_argument(
+ "--dropout-input",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the input (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--dropout-features",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the features (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--final-dim",
+ type=int,
+ default=0,
+        help="project final representations and targets to this many dimensions. set to encoder_embed_dim if <= 0",
+ )
+
+ parser.add_argument(
+ "--untie-final-proj",
+ type=bool,
+ default=False,
+ help="use separate projection for each target",
+ )
+
+ parser.add_argument(
+ "--layer-norm-first",
+ type=bool,
+ default=False,
+ help="apply layernorm first in the transformer",
+ )
+
+ parser.add_argument(
+ "--conv-feature-layers",
+ type=str,
+ default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+ help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
+ )
+
+ parser.add_argument(
+ "--conv-bias",
+ type=bool,
+ default=False,
+ help="include bias in conv encoder",
+ )
+
+ parser.add_argument(
+ "--logit-temp",
+ type=float,
+ default=0.1,
+ help="temperature to divide logits by",
+ )
+
+ parser.add_argument(
+ "--target-glu",
+ type=bool,
+ default=False,
+ help="adds projection + glu to targets",
+ )
+
+ parser.add_argument(
+ "--feature-grad-mult",
+ type=float,
+ default=1.0,
+ help="multiply feature extractor var grads by this",
+ )
+
+ # masking
+ parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
+
+ parser.add_argument(
+ "--mask-prob",
+ type=float,
+ default=0.65,
+ help="probability of replacing a token with mask",
+ )
+
+ parser.add_argument(
+ "--mask-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length",
+ )
+
+ parser.add_argument(
+ "--mask-other",
+ type=float,
+ default=0,
+        help="secondary mask argument (used for more complex distributions); see help in compute_mask_indices",
+ )
+
+ parser.add_argument(
+ "--no-mask-overlap",
+ type=bool,
+ default=False,
+        help="whether to prevent masks from overlapping",
+ )
+
+ parser.add_argument(
+ "--mask-min-space",
+ type=int,
+ default=1,
+ help="min space between spans (if no overlap is enabled)",
+ )
+
+ # channel masking
+ parser.add_argument(
+ "--mask-channel-length",
+ type=int,
+ default=10,
+ help="length of the mask for features (channels)",
+ )
+
+ parser.add_argument(
+ "--mask-channel-prob",
+ type=float,
+ default=0.0,
+ help="probability of replacing a feature with 0",
+ )
+
+ parser.add_argument(
+ "--mask-channel-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length for channel masking",
+ )
+
+ parser.add_argument(
+ "--mask-channel-other",
+ type=float,
+ default=0,
+        help="secondary mask argument (used for more complex distributions); see help in compute_mask_indices",
+ )
+
+ parser.add_argument(
+ "--no-mask-channel-overlap",
+ type=bool,
+ default=False,
+        help="whether to prevent channel masks from overlapping",
+ )
+
+ parser.add_argument(
+ "--mask-channel-min-space",
+ type=int,
+ default=1,
+ help="min space between spans (if no overlap is enabled)",
+ )
+
+ # positional embeddings
+ parser.add_argument(
+ "--conv-pos",
+ type=int,
+ default=128,
+ help="number of filters for convolutional positional embeddings",
+ )
+
+ parser.add_argument(
+ "--conv-pos-groups",
+ type=int,
+ default=16,
+ help="number of groups for convolutional positional embedding",
+ )
+
+ parser.add_argument(
+ "--conv-pos-batch-norm",
+ type=bool,
+ default=False,
+ help="use batch norm instead of weight norm in conv_pos (for bf16 models)",
+ )
+
+ parser.add_argument(
+ "--latent-temp",
+ type=float,
+ nargs="*",
+ default=[2, 0.5, 0.999995],
+ help="legacy (to be removed)",
+ )
+
+ # loss computation
+ parser.add_argument(
+ "--skip-masked",
+ type=bool,
+ default=False,
+ help="skip computing losses over masked frames",
+ )
+
+ parser.add_argument(
+ "--skip-nomask",
+ type=bool,
+ default=False,
+ help="skip computing losses over unmasked frames",
+ )
+
+ parser.add_argument(
+ "--checkpoint-activations",
+ type=bool,
+ default=False,
+ help="recompute activations and save memory for extra compute",
+ )
+
+ parser.add_argument(
+ "--pred-masked-weight",
+ type=float,
+ default=1,
+ help="weight for masked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--pred-nomask-weight",
+ type=float,
+ default=0,
+        help="weight for unmasked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--loss-weights",
+ type=float,
+ nargs="*",
+ default=[10],
+        help="weights for additional loss terms (e.g. the feature penalty)",
+ )
+
+ # FP16 optimization
+ parser.add_argument(
+ "--required-seq-len-multiple",
+ type=int,
+ default=2,
+ help="pad the input to encoder such that the sequence length is divisible by multiple",
+ )
+
+ parser.add_argument(
+        "--attn-type", type=str, default="", help="if set to 'espnet', use ESPnet-style MHA"
+ )
+
+ parser.add_argument(
+ "--pos-enc-type",
+ type=str,
+ default="abs",
+ help="Positional encoding type to use in conformer",
+ )
+
+ parser.add_argument(
+ "--num-classes",
+ type=int,
+ nargs="*",
+ default=[504],
+        help="""number of classes; a little larger than the number of clusters
+        (the largest index is used for padding), and the value should be a
+        multiple of 4 for faster computation""",
+ )
+
+
+class HubertModel(nn.Module):
+ def __init__(
+ self,
+ cfg,
+ ) -> None:
+ super().__init__()
+ feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
+ self.embed = feature_enc_layers[-1][0]
+
+ self.feature_extractor = ConvFeatureExtractionModel(
+ conv_layers=feature_enc_layers,
+ dropout=0.0,
+ mode=cfg.extractor_mode,
+ conv_bias=cfg.conv_bias,
+ )
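+        # feature_ds_rate is the total stride of the conv feature extractor
+        # (320 with the default conv_feature_layers, i.e. 50 frames per second
+        # at 16 kHz), so feat2tar_ratio maps feature frames to label frames.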
+ feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
+ self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
+
+ self.post_extract_proj = (
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
+ if self.embed != cfg.encoder_embed_dim
+ else None
+ )
+
+ self.mask_prob = cfg.mask_prob
+ self.mask_selection = cfg.mask_selection
+ self.mask_other = cfg.mask_other
+ self.mask_length = cfg.mask_length
+ self.no_mask_overlap = cfg.no_mask_overlap
+ self.mask_min_space = cfg.mask_min_space
+
+ self.mask_channel_prob = cfg.mask_channel_prob
+ self.mask_channel_selection = cfg.mask_channel_selection
+ self.mask_channel_other = cfg.mask_channel_other
+ self.mask_channel_length = cfg.mask_channel_length
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+ self.mask_channel_min_space = cfg.mask_channel_min_space
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+ self.feature_grad_mult = cfg.feature_grad_mult
+ self.logit_temp = cfg.logit_temp
+ self.skip_masked = cfg.skip_masked
+ self.skip_nomask = cfg.skip_nomask
+
+ final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
+
+ self.mask_emb = nn.Parameter(
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
+ )
+
+ self.encoder = TransformerEncoder(cfg)
+ self.layer_norm = LayerNorm(self.embed)
+
+ self.target_glu = None
+ if cfg.target_glu:
+ self.target_glu = nn.Sequential(
+ nn.Linear(final_dim, final_dim * 2), nn.GLU()
+ )
+
+ self.untie_final_proj = cfg.untie_final_proj
+ if self.untie_final_proj:
+ self.final_proj = nn.Linear(
+ cfg.encoder_embed_dim, final_dim * len(cfg.num_classes)
+ )
+ else:
+ self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
+
+ # modules below are not needed during fine-tuning
+ self.num_classes = cfg.num_classes
+ self.label_embs_concat = nn.Parameter(
+ torch.FloatTensor(sum(self.num_classes), final_dim)
+ )
+ self.pred_masked_weight = cfg.pred_masked_weight
+ self.pred_nomask_weight = cfg.pred_nomask_weight
+ self.loss_weights = cfg.loss_weights
+ nn.init.uniform_(self.label_embs_concat)
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        """Upgrade a (possibly old) state dict for new versions of fairseq."""
+        # nn.Module does not define upgrade_state_dict_named, so no super() call here.
+        return state_dict
+
+ def apply_mask(self, x, padding_mask, target_list):
+ B, T, C = x.shape
+ if self.mask_prob > 0:
+ mask_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
+ x[mask_indices] = self.mask_emb.to(x.dtype)
+ else:
+ mask_indices = None
+
+ if self.mask_channel_prob > 0:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ torch.from_numpy(mask_channel_indices)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x[mask_channel_indices] = 0
+
+ return x, mask_indices
+
+ def compute_nce(self, x, pos, negs):
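+        # x: (S, D) projected features; pos: (S, D) positive targets;
+        # negs: (Neg, S, D) negative targets. The returned logits have shape
+        # (S, Neg + 1) with the positive at index 0, which is why compute_loss()
+        # uses all-zero target vectors with cross-entropy.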
+ neg_is_pos = (pos == negs).all(-1)
+ pos = pos.unsqueeze(0)
+ targets = torch.cat([pos, negs], dim=0)
+
+ logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
+ logits /= self.logit_temp
+ if neg_is_pos.any():
+ logits[1:][neg_is_pos] = float("-inf")
+ logits = logits.transpose(0, 1) # (num_x, num_cls+1)
+ return logits
+
+ def forward_features(self, source: torch.Tensor) -> torch.Tensor:
+ if self.feature_grad_mult > 0:
+ features = self.feature_extractor(source)
+ if self.feature_grad_mult != 1.0:
+ features = GradMultiply.apply(features, self.feature_grad_mult)
+ else:
+ with torch.no_grad():
+ features = self.feature_extractor(source)
+ return features
+
+ def forward_targets(
+ self,
+ features: torch.Tensor,
+ target_list: List[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Trim features to ensure labels exist and then get aligned labels
+ feat_tsz = features.size(2)
+ targ_tsz = min([t.size(1) for t in target_list])
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
+ features = features[..., :feat_tsz]
+ target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
+ target_list = [t[:, target_inds.long()] for t in target_list]
+ return features, target_list
+
+ def forward_padding_mask(
+ self,
+ features: torch.Tensor,
+ padding_mask: torch.Tensor,
+ ) -> torch.Tensor:
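+        # The incoming padding mask is at waveform resolution while `features`
+        # is at frame resolution: drop the remainder samples, fold the mask into
+        # (batch, num_frames, samples_per_frame), and mark a frame as padded
+        # only when all of its samples are padded.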
+ extra = padding_mask.size(1) % features.size(1)
+ if extra > 0:
+ padding_mask = padding_mask[:, :-extra]
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
+ padding_mask = padding_mask.all(-1)
+ return padding_mask
+
+ def forward(
+ self,
+ source: torch.Tensor,
+ target_list: Optional[List[torch.Tensor]] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = True,
+ features_only: bool = False,
+ output_layer: Optional[int] = None,
+ ):
+ """output layer is 1-based"""
+ features = self.forward_features(source)
+ if target_list is not None:
+ features, target_list = self.forward_targets(features, target_list)
+
+ features_pen = features.float().pow(2).mean()
+
+ features = features.transpose(1, 2)
+ features = self.layer_norm(features)
+ unmasked_features = features.clone()
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ features = self.dropout_input(features)
+ unmasked_features = self.dropout_features(unmasked_features)
+
+ if mask:
+ x, mask_indices = self.apply_mask(features, padding_mask, target_list)
+ else:
+ x = features
+ mask_indices = None
+
+ # feature: (B, T, D), float
+ # target: (B, T), long
+ # x: (B, T, D), float
+ # padding_mask: (B, T), bool
+ # mask_indices: (B, T), bool
+ x, _ = self.encoder(
+ x,
+ padding_mask=padding_mask,
+ layer=None if output_layer is None else output_layer - 1,
+ )
+
+ if features_only:
+ return {"x": x, "padding_mask": padding_mask, "features": features}
+
+ def compute_pred(proj_x, target, label_embs):
+ # compute logits for the i-th label set
+ y = torch.index_select(label_embs, 0, target.long())
+ negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
+ if self.target_glu:
+ y = self.target_glu(y)
+ negs = self.target_glu(negs)
+ # proj_x: (S, D)
+ # y: (S, D)
+ # negs: (Neg, S, D)
+ return self.compute_nce(proj_x, y, negs)
+
+ label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
+
+ if not self.skip_masked:
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
+ proj_x_m = self.final_proj(x[masked_indices])
+ if self.untie_final_proj:
+ proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
+ else:
+ proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
+ logit_m_list = [
+ compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
+ for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
+ ]
+ else:
+ logit_m_list = [None for _ in target_list]
+
+ if not self.skip_nomask:
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
+ proj_x_u = self.final_proj(x[nomask_indices])
+ if self.untie_final_proj:
+ proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
+ else:
+ proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
+
+ logit_u_list = [
+ compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
+ for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
+ ]
+ else:
+ logit_u_list = [None for _ in target_list]
+
+ # result = {
+ # "logit_m_list": logit_m_list,
+ # "logit_u_list": logit_u_list,
+ # "padding_mask": padding_mask,
+ # "features_pen": features_pen,
+ # }
+ return self.compute_loss(logit_m_list, logit_u_list, features_pen)
+
+ def extract_features(
+ self,
+ source: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = False,
+ ret_conv: bool = False,
+ output_layer: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ res = self.forward(
+ source,
+ padding_mask=padding_mask,
+ mask=mask,
+ features_only=True,
+ output_layer=output_layer,
+ )
+ feature = res["features"] if ret_conv else res["x"]
+ return feature, res["padding_mask"]
+
+ def get_logits(self, net_output, is_masked=True):
+ if is_masked:
+ logits_list = net_output["logit_m_list"]
+ else:
+ logits_list = net_output["logit_u_list"]
+ logits_list = [x.float() for x in logits_list if x is not None]
+ return logits_list
+
+ def get_targets(self, net_output, is_masked=True):
+ logits_list = self.get_logits(net_output, is_masked)
+ targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
+ return targets_list
+
+ def get_extra_losses(self, net_output):
+ extra_losses = []
+ names = []
+
+ if "features_pen" in net_output:
+ extra_losses.append(net_output["features_pen"])
+ names.append("features_pen")
+
+ return extra_losses, names
+
+ def remove_pretraining_modules(self):
+ self.target_glu = None
+ self.final_proj = None
+
+ def compute_loss(self, logit_m_list, logit_u_list, features_pen):
+ loss = 0.0
+ sample_size = 0
+ logging_output = {}
+ reduce = True
+ reduction = "sum" if reduce else "none"
+
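+        # Total loss:
+        #   pred_masked_weight * sum_i CE(logit_m_i, 0)   over masked frames
+        # + pred_nomask_weight * sum_i CE(logit_u_i, 0)   over unmasked frames
+        # + sum_k loss_weights[k] * extra_losses[k] * sample_size
+        # where the only extra loss here is the feature penalty features_pen
+        # and class 0 is the positive target from compute_nce().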
+ loss_m_list = []
+ logp_m_list = [x.float() for x in logit_m_list if x is not None]
+ targ_m_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logp_m_list]
+ assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
+ for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
+ loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
+ loss_m_list.append(loss_m)
+ logging_output[f"loss_m_{i}"] = loss_m.detach().item()
+ if self.pred_masked_weight > 0:
+ loss += self.pred_masked_weight * sum(loss_m_list)
+ sample_size += targ_m_list[0].numel()
+
+ loss_u_list = []
+ logp_u_list = [x.float() for x in logit_u_list if x is not None]
+ targ_u_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logp_u_list]
+ assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
+ for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
+ loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
+ loss_u_list.append(loss_u)
+ logging_output[f"loss_u_{i}"] = loss_u.detach().item()
+ if self.pred_nomask_weight > 0:
+ loss += self.pred_nomask_weight * sum(loss_u_list)
+ sample_size += targ_u_list[0].numel()
+
+ if self.loss_weights is not None:
+ extra_losses = []
+ names = []
+ extra_losses.append(features_pen)
+ names.append("features_pen")
+ if torch.is_tensor(extra_losses):
+ extra_losses = [extra_losses]
+ names = [names]
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
+ assert len(extra_losses) == len(
+ self.loss_weights
+ ), f"{len(extra_losses)}, {len(self.loss_weights)}"
+ for p, n, coef in zip(extra_losses, names, self.loss_weights):
+ if coef != 0 and p is not None:
+ p = coef * p.float() * sample_size
+ loss += p
+ logging_output[f"loss_{n}"] = p.item()
+
+ logging_output = {
+ "loss": loss.item() if reduce else loss,
+ **logging_output,
+ }
+
+ # for lk in self.log_keys:
+ # if lk in net_output:
+ # logging_output[lk] = float((net_output[lk]))
+
+ def compute_correct(logits):
+ if logits.numel() == 0:
+ return 0, 0
+ else:
+ assert logits.dim() > 1, logits.shape
+ max = logits.argmax(-1) == 0
+ min = logits.argmin(-1) == 0
+ both = max & min
+ corr = max.long().sum().item() - both.long().sum().item()
+ count = max.numel()
+ return corr, count
+
+ with torch.no_grad():
+ for i, logp_m in enumerate(logp_m_list):
+ corr_m, count_m = compute_correct(logp_m)
+ logging_output[f"correct_m_{i}"] = corr_m
+ logging_output[f"count_m_{i}"] = count_m
+
+ for i, logp_u in enumerate(logp_u_list):
+ corr_u, count_u = compute_correct(logp_u)
+ logging_output[f"correct_u_{i}"] = corr_u
+ logging_output[f"count_u_{i}"] = count_u
+
+ return loss, sample_size, logging_output
diff --git a/egs/librispeech/SSL/hubert/hubert_ce.py b/egs/librispeech/SSL/hubert/hubert_ce.py
new file mode 100644
index 000000000..ccdd63efd
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/hubert_ce.py
@@ -0,0 +1,940 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import argparse
+import logging
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils import GradMultiply, LayerNorm
+from wav2vec2_module import ConvFeatureExtractionModel, TransformerEncoder
+
+
+def compute_mask_indices(
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+ require_same_masks: bool = True,
+ mask_dropout: float = 0.0,
+ add_masks: bool = False,
+ seed: Optional[int] = None,
+ epoch: Optional[int] = None,
+ indices: Optional[torch.Tensor] = None,
+ idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
+ num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape
+
+ Args:
+        shape: the shape for which to compute masks.
+ should be of size 2 where first element is batch size and 2nd is timesteps
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+ mask_type: how to compute mask lengths
+ static = fixed size
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+            poisson = sample from Poisson distribution with lambda = mask length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, will use a recursive algorithm that prevents masked spans from overlapping
+        min_space: only used if no_overlap is True; this is how many elements to keep unmasked between spans
+        require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
+        mask_dropout: randomly drop out this percentage of masks in each example
+ """
+
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ if num_mask_ver == 1:
+ all_num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * all_sz / float(mask_length)
+ + np.random.rand()
+ )
+ all_num_mask = max(min_masks, all_num_mask)
+
+ mask_idcs = []
+ for i in range(bsz):
+ if seed is not None and epoch is not None and indices is not None:
+ seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
+ else:
+ seed_i = None
+
+ rng = np.random.default_rng(seed_i)
+
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ assert sz >= 0, sz
+ else:
+ sz = all_sz
+
+ if num_mask_ver == 1:
+ if padding_mask is not None:
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + np.random.rand()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ num_mask = all_num_mask
+ elif num_mask_ver == 2:
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + rng.random()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ raise ValueError()
+
+ if mask_type == "static":
+ lengths = np.full(num_mask, mask_length)
+ elif mask_type == "uniform":
+            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
+ elif mask_type == "normal":
+ lengths = rng.normal(mask_length, mask_other, size=num_mask)
+ lengths = [max(1, int(round(x))) for x in lengths]
+ elif mask_type == "poisson":
+ lengths = rng.poisson(mask_length, size=num_mask)
+ lengths = [int(round(x)) for x in lengths]
+ else:
+ raise Exception("unknown mask selection " + mask_type)
+
+ if sum(lengths) == 0:
+ if mask_type == "static":
+                raise ValueError("this should never happen")
+ else:
+ lengths = [min(mask_length, sz - 1)]
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+                span_start = rng.integers(s, e - length)
+ mask_idc.extend(span_start + i for i in range(length))
+
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = np.fromiter(
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    int,
+ )
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ probs = lens / np.sum(lens)
+ c = rng.choice(len(parts), p=probs)
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ if idc_select_ver == 1:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+ elif idc_select_ver == 2:
+ mask_idc = rng.choice(sz, num_mask, replace=False)
+ else:
+ raise ValueError()
+
+ mask_idc = np.asarray(
+ [
+ mask_idc[j] + offset
+ for j in range(len(mask_idc))
+ for offset in range(lengths[j])
+ ]
+ )
+
+ mask_idc = np.unique(mask_idc[mask_idc < sz])
+ if len(mask_idc) >= sz:
+ raise ValueError(
+ (
+ f"the entire sequence is masked. "
+                    f"sz={sz}; mask_idc={mask_idc}; "
+ f"index={indices[i] if indices is not None else None}"
+ )
+ )
+ mask_idcs.append(mask_idc)
+
+ target_len = None
+ if require_same_masks:
+ if add_masks:
+ target_len = max([len(m) for m in mask_idcs])
+ else:
+ target_len = min([len(m) for m in mask_idcs])
+
+ for i, mask_idc in enumerate(mask_idcs):
+ if target_len is not None and len(mask_idc) > target_len:
+ mask_idc = rng.choice(mask_idc, target_len, replace=False)
+
+ mask[i, mask_idc] = True
+
+ if target_len is not None and len(mask_idc) < target_len:
+ unmasked = np.flatnonzero(~mask[i])
+ to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
+ mask[i, to_mask] = True
+
+ if mask_dropout > 0:
+ masked = np.flatnonzero(mask[i])
+ num_holes = np.rint(len(masked) * mask_dropout).astype(int)
+ to_drop = rng.choice(masked, num_holes, replace=False)
+ mask[i, to_drop] = False
+
+ return mask
+
+
+def add_hubert_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--label-rate",
+ type=float,
+ default=50,
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=float,
+ default=16000,
+ )
+
+ parser.add_argument(
+ "--extractor-mode",
+ type=str,
+ default="default",
+        help="""mode for feature extractor, should be in EXTRACTOR_MODE_CHOICES. default has a single group
+ norm with d groups in the first conv block, whereas layer_norm
+ has layer norms in every block (meant to use with normalize=True)""",
+ )
+ parser.add_argument(
+ "--encoder-layers",
+ type=int,
+ default=12,
+ help="num encoder layers in the transformer",
+ )
+
+ parser.add_argument(
+ "--encoder-embed-dim",
+ type=int,
+ default=768,
+ help="encoder embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-ffn-embed-dim",
+ type=int,
+ default=3072,
+ help="encoder embedding dimension for FFN",
+ )
+
+ parser.add_argument(
+ "--encoder-attention-heads",
+ type=int,
+ default=12,
+ help="num encoder attention heads",
+ )
+
+ parser.add_argument(
+ "--activation-fn",
+ type=str,
+ choices=[
+ "relu",
+ "gelu",
+ "gelu_fast",
+ "gelu_accurate",
+ "tanh",
+ "linear",
+ ],
+ default="gelu",
+ help="activation function to use",
+ )
+
+ parser.add_argument(
+ "--layer-type",
+ type=str,
+ choices=["transformer", "conformer", "trf_adp"],
+ default="transformer",
+ help="layer type in encoder",
+ )
+
+ # dropouts
+ parser.add_argument(
+ "--dropout",
+ type=float,
+ default=0.1,
+ help="dropout probability for the transformer",
+ )
+
+ parser.add_argument(
+ "--attention-dropout",
+ type=float,
+ default=0.1,
+ help="dropout probability for attention weights",
+ )
+
+ parser.add_argument(
+ "--activation-dropout",
+ type=float,
+ default=0.0,
+ help="dropout probability after activation in FFN",
+ )
+
+ parser.add_argument(
+ "--encoder-layerdrop",
+ type=float,
+ default=0.0,
+        help="probability of dropping a transformer layer",
+ )
+
+ parser.add_argument(
+ "--dropout-input",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the input (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--dropout-features",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the features (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--final-dim",
+ type=int,
+ default=0,
+        help="project final representations and targets to this many dimensions. set to encoder_embed_dim if <= 0",
+ )
+
+ parser.add_argument(
+ "--untie-final-proj",
+ type=bool,
+ default=False,
+ help="use separate projection for each target",
+ )
+
+ parser.add_argument(
+ "--layer-norm-first",
+ type=bool,
+ default=False,
+ help="apply layernorm first in the transformer",
+ )
+
+ parser.add_argument(
+ "--conv-feature-layers",
+ type=str,
+ default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+ help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
+ )
+
+ parser.add_argument(
+ "--conv-bias",
+ type=bool,
+ default=False,
+ help="include bias in conv encoder",
+ )
+
+ parser.add_argument(
+ "--logit-temp",
+ type=float,
+ default=0.1,
+ help="temperature to divide logits by",
+ )
+
+ parser.add_argument(
+ "--target-glu",
+ type=bool,
+ default=False,
+ help="adds projection + glu to targets",
+ )
+
+ parser.add_argument(
+ "--feature-grad-mult",
+ type=float,
+ default=1.0,
+ help="multiply feature extractor var grads by this",
+ )
+
+ # masking
+ parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
+
+ parser.add_argument(
+ "--mask-prob",
+ type=float,
+ default=0.65,
+ help="probability of replacing a token with mask",
+ )
+
+ parser.add_argument(
+ "--mask-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length",
+ )
+
+ parser.add_argument(
+ "--mask-other",
+ type=float,
+ default=0,
+        help="secondary mask argument (used for more complex distributions); see help in compute_mask_indices",
+ )
+
+ parser.add_argument(
+ "--no-mask-overlap",
+ type=bool,
+ default=False,
+        help="whether to prevent masks from overlapping",
+ )
+
+ parser.add_argument(
+ "--mask-min-space",
+ type=int,
+ default=1,
+ help="min space between spans (if no overlap is enabled)",
+ )
+
+ # channel masking
+ parser.add_argument(
+ "--mask-channel-length",
+ type=int,
+ default=10,
+ help="length of the mask for features (channels)",
+ )
+
+ parser.add_argument(
+ "--mask-channel-prob",
+ type=float,
+ default=0.0,
+ help="probability of replacing a feature with 0",
+ )
+
+ parser.add_argument(
+ "--mask-channel-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length for channel masking",
+ )
+
+ parser.add_argument(
+ "--mask-channel-other",
+ type=float,
+ default=0,
+        help="secondary mask argument (used for more complex distributions); see help in compute_mask_indices",
+ )
+
+ parser.add_argument(
+ "--no-mask-channel-overlap",
+ type=bool,
+ default=False,
+        help="whether to prevent channel masks from overlapping",
+ )
+
+ parser.add_argument(
+ "--mask-channel-min-space",
+ type=int,
+ default=1,
+ help="min space between spans (if no overlap is enabled)",
+ )
+
+ # positional embeddings
+ parser.add_argument(
+ "--conv-pos",
+ type=int,
+ default=128,
+ help="number of filters for convolutional positional embeddings",
+ )
+
+ parser.add_argument(
+ "--conv-pos-groups",
+ type=int,
+ default=16,
+ help="number of groups for convolutional positional embedding",
+ )
+
+ parser.add_argument(
+ "--conv-pos-batch-norm",
+ type=bool,
+ default=False,
+ help="use batch norm instead of weight norm in conv_pos (for bf16 models)",
+ )
+
+ parser.add_argument(
+ "--latent-temp",
+ type=float,
+ nargs="*",
+ default=[2, 0.5, 0.999995],
+ help="legacy (to be removed)",
+ )
+
+ # loss computation
+ parser.add_argument(
+ "--skip-masked",
+ type=bool,
+ default=False,
+ help="skip computing losses over masked frames",
+ )
+
+ parser.add_argument(
+ "--skip-nomask",
+ type=bool,
+ default=False,
+ help="skip computing losses over unmasked frames",
+ )
+
+ parser.add_argument(
+ "--checkpoint-activations",
+ type=bool,
+ default=False,
+ help="recompute activations and save memory for extra compute",
+ )
+
+ parser.add_argument(
+ "--pred-masked-weight",
+ type=float,
+ default=1,
+ help="weight for masked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--pred-nomask-weight",
+ type=float,
+ default=0,
+        help="weight for unmasked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--loss-weights",
+ type=float,
+ nargs="*",
+ default=[10],
+        help="weights for additional loss terms (e.g. the feature penalty)",
+ )
+
+ # FP16 optimization
+ parser.add_argument(
+ "--required-seq-len-multiple",
+ type=int,
+ default=2,
+ help="pad the input to encoder such that the sequence length is divisible by multiple",
+ )
+
+ parser.add_argument(
+        "--attn-type", type=str, default="", help="if set to 'espnet', use ESPnet-style MHA"
+ )
+
+ parser.add_argument(
+ "--pos-enc-type",
+ type=str,
+ default="abs",
+ help="Positional encoding type to use in conformer",
+ )
+
+ parser.add_argument(
+ "--num-classes",
+ type=int,
+ nargs="*",
+ default=[504],
+        help="""number of classes; a little larger than the number of clusters
+        (the largest index is used for padding), and the value should be a
+        multiple of 4 for faster computation""",
+ )
+
+
+class HubertModel(nn.Module):
+ def __init__(
+ self,
+ cfg,
+ ) -> None:
+ super().__init__()
+ feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
+ self.embed = feature_enc_layers[-1][0]
+
+ self.feature_extractor = ConvFeatureExtractionModel(
+ conv_layers=feature_enc_layers,
+ dropout=0.0,
+ mode=cfg.extractor_mode,
+ conv_bias=cfg.conv_bias,
+ )
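+ # The conv feature extractor downsamples the waveform by the product of
+ # its strides; feat2tar_ratio below converts feature frames to k-means
+ # label frames (label_rate labels per second of audio).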
+ feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
+ self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
+
+ self.post_extract_proj = (
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
+ if self.embed != cfg.encoder_embed_dim
+ else None
+ )
+
+ self.mask_prob = cfg.mask_prob
+ self.mask_selection = cfg.mask_selection
+ self.mask_other = cfg.mask_other
+ self.mask_length = cfg.mask_length
+ self.no_mask_overlap = cfg.no_mask_overlap
+ self.mask_min_space = cfg.mask_min_space
+
+ self.mask_channel_prob = cfg.mask_channel_prob
+ self.mask_channel_selection = cfg.mask_channel_selection
+ self.mask_channel_other = cfg.mask_channel_other
+ self.mask_channel_length = cfg.mask_channel_length
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+ self.mask_channel_min_space = cfg.mask_channel_min_space
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+ self.feature_grad_mult = cfg.feature_grad_mult
+ self.logit_temp = cfg.logit_temp
+ self.skip_masked = cfg.skip_masked
+ self.skip_nomask = cfg.skip_nomask
+
+ self.mask_emb = nn.Parameter(
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
+ )
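+ # Learnable embedding that replaces masked frames in apply_mask().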
+
+ self.encoder = TransformerEncoder(cfg)
+ self.layer_norm = LayerNorm(self.embed)
+
+ self.untie_final_proj = cfg.untie_final_proj
+ self.final_proj = nn.Linear(cfg.encoder_embed_dim, sum(cfg.num_classes))
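+ # final_proj maps encoder outputs to the concatenated codebooks; each
+ # entry of num_classes is the size of one target stream.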
+
+ # modules below are not needed during fine-tuning
+ self.num_classes = cfg.num_classes
+ self.pred_masked_weight = cfg.pred_masked_weight
+ self.pred_nomask_weight = cfg.pred_nomask_weight
+ self.loss_weights = cfg.loss_weights
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
+
+ # Unlike fairseq's BaseFairseqModel, nn.Module has no
+ # upgrade_state_dict_named, so there is nothing to delegate to here.
+ return state_dict
+
+ def apply_mask(self, x, padding_mask, target_list):
+ B, T, C = x.shape
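+ # Time masking: sample spans along the time axis and replace the masked
+ # frames with the learned mask embedding; these frames are the ones the
+ # model is trained to predict.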
+ if self.mask_prob > 0:
+ mask_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
+ x[mask_indices] = self.mask_emb.to(x.dtype)
+ else:
+ mask_indices = None
+
+ if self.mask_channel_prob > 0:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ torch.from_numpy(mask_channel_indices)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x[mask_channel_indices] = 0
+
+ return x, mask_indices
+
+ def forward_features(self, source: torch.Tensor) -> torch.Tensor:
+ if self.feature_grad_mult > 0:
+ features = self.feature_extractor(source)
+ if self.feature_grad_mult != 1.0:
+ features = GradMultiply.apply(features, self.feature_grad_mult)
+ else:
+ with torch.no_grad():
+ features = self.feature_extractor(source)
+ return features
+
+ def forward_targets(
+ self,
+ features: torch.Tensor,
+ target_list: List[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Trim features to ensure labels exist and then get aligned labels
+ feat_tsz = features.size(2)
+ targ_tsz = min([t.size(1) for t in target_list])
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
+ features = features[..., :feat_tsz]
+ target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
+ target_list = [t[:, target_inds.long()] for t in target_list]
+ return features, target_list
+
+ def forward_padding_mask(
+ self,
+ features: torch.Tensor,
+ padding_mask: torch.Tensor,
+ ) -> torch.Tensor:
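+ # Downsample the sample-level padding mask to feature resolution: a
+ # feature frame counts as padding only if all samples it covers are
+ # padding.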
+ extra = padding_mask.size(1) % features.size(1)
+ if extra > 0:
+ padding_mask = padding_mask[:, :-extra]
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
+ padding_mask = padding_mask.all(-1)
+ return padding_mask
+
+ def forward(
+ self,
+ source: torch.Tensor,
+ target_list: Optional[List[torch.Tensor]] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = True,
+ features_only: bool = False,
+ output_layer: Optional[int] = None,
+ ):
+ """output layer is 1-based"""
+ features = self.forward_features(source)
+ if target_list is not None:
+ features, target_list = self.forward_targets(features, target_list)
+
+ features_pen = features.float().pow(2).mean()
+
+ features = features.transpose(1, 2)
+ features = self.layer_norm(features)
+ unmasked_features = features.clone()
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ features = self.dropout_input(features)
+ unmasked_features = self.dropout_features(unmasked_features)
+
+ if mask:
+ x, mask_indices = self.apply_mask(features, padding_mask, target_list)
+ else:
+ x = features
+ mask_indices = None
+
+ # feature: (B, T, D), float
+ # target: (B, T), long
+ # x: (B, T, D), float
+ # padding_mask: (B, T), bool
+ # mask_indices: (B, T), bool
+ x, _ = self.encoder(
+ x,
+ padding_mask=padding_mask,
+ layer=None if output_layer is None else output_layer - 1,
+ )
+
+ if features_only:
+ return {"x": x, "padding_mask": padding_mask, "features": features}
+
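+ # Pretraining losses: project encoder outputs of masked (and optionally
+ # unmasked) frames to the codebook size and scale by logit_temp before
+ # the cross-entropy in compute_loss().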
+ if not self.skip_masked:
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
+ proj_x_m = self.final_proj(x[masked_indices])
+ proj_x_m /= self.logit_temp
+ logit_m_list = [proj_x_m for _ in range(len(target_list))]
+ else:
+ logit_m_list = [None for _ in target_list]
+
+ if not self.skip_nomask:
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
+ proj_x_u = self.final_proj(x[nomask_indices])
+ proj_x_u /= self.logit_temp
+ logit_u_list = [proj_x_u for _ in range(len(target_list))]
+ else:
+ logit_u_list = [None for _ in target_list]
+
+ # result = {
+ # "logit_m_list": logit_m_list,
+ # "logit_u_list": logit_u_list,
+ # "padding_mask": padding_mask,
+ # "features_pen": features_pen,
+ # }
+ targ_m_list = target_list[0][masked_indices]
+ targ_m_list = targ_m_list.long()
+ targ_m_list = [targ_m_list for _ in range(len(target_list))]
+
+ targ_u_list = target_list[0][nomask_indices]
+ targ_u_list = targ_u_list.long()
+ targ_u_list = [targ_u_list for _ in range(len(target_list))]
+ return self.compute_loss(
+ logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
+ )
+
+ def extract_features(
+ self,
+ source: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = False,
+ ret_conv: bool = False,
+ output_layer: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ res = self.forward(
+ source,
+ padding_mask=padding_mask,
+ mask=mask,
+ features_only=True,
+ output_layer=output_layer,
+ )
+ feature = res["features"] if ret_conv else res["x"]
+ return feature, res["padding_mask"]
+
+ def get_logits(self, net_output, is_masked=True):
+ if is_masked:
+ logits_list = net_output["logit_m_list"]
+ else:
+ logits_list = net_output["logit_u_list"]
+ logits_list = [x.float() for x in logits_list if x is not None]
+ return logits_list
+
+ def get_targets(self, net_output, is_masked=True):
+ logits_list = self.get_logits(net_output, is_masked)
+ targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
+ return targets_list
+
+ def get_extra_losses(self, net_output):
+ extra_losses = []
+ names = []
+
+ if "features_pen" in net_output:
+ extra_losses.append(net_output["features_pen"])
+ names.append("features_pen")
+
+ return extra_losses, names
+
+ def remove_pretraining_modules(self):
+ self.final_proj = None
+
+ def compute_loss(
+ self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
+ ):
+ loss = 0.0
+ sample_size = 0
+ logging_output = {}
+ reduce = True
+ reduction = "sum" if reduce else "none"
+
+ loss_m_list = []
+ logp_m_list = [x.float() for x in logit_m_list if x is not None]
+ logp_m_list = torch.cat(logp_m_list)
+ targ_m_list = torch.cat(targ_m_list)
+
+ loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
+ loss_m_list.append(loss_m)
+ logging_output[f"loss_m_0"] = loss_m.detach().item()
+
+ assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
+ if self.pred_masked_weight > 0:
+ loss += self.pred_masked_weight * sum(loss_m_list)
+ sample_size += len(targ_m_list)
+
+ loss_u_list = []
+ logp_u_list = [x.float() for x in logit_u_list if x is not None]
+ logp_u_list = torch.cat(logp_u_list)
+ targ_u_list = torch.cat(targ_u_list)
+
+ loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
+ loss_u_list.append(loss_u)
+ logging_output[f"loss_u_0"] = loss_u.detach().item()
+
+ assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
+ if self.pred_nomask_weight > 0:
+ loss += self.pred_nomask_weight * sum(loss_u_list)
+ sample_size += len(targ_u_list)
+
+ if self.loss_weights is not None:
+ extra_losses = []
+ names = []
+ extra_losses.append(features_pen)
+ names.append("features_pen")
+ if torch.is_tensor(extra_losses):
+ extra_losses = [extra_losses]
+ names = [names]
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
+ assert len(extra_losses) == len(
+ self.loss_weights
+ ), f"{len(extra_losses)}, {len(self.loss_weights)}"
+ for p, n, coef in zip(extra_losses, names, self.loss_weights):
+ if coef != 0 and p is not None:
+ p = coef * p.float() * sample_size
+ loss += p
+ logging_output[f"loss_{n}"] = p.item()
+
+ logging_output = {
+ "loss": loss.item() if reduce else loss,
+ **logging_output,
+ }
+
+ # for lk in self.log_keys:
+ # if lk in net_output:
+ # logging_output[lk] = float((net_output[lk]))
+
+ def compute_correct(logits, target):
+ if logits.numel() == 0:
+ return 0, 0
+ else:
+ assert logits.dim() > 1, logits.shape
+ max = logits.argmax(-1) == target
+ min = logits.argmin(-1) == target
+ both = max & min
+ corr = max.long().sum().item() - both.long().sum().item()
+ count = max.numel()
+ return corr, count
+
+ with torch.no_grad():
+ corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
+ logging_output[f"correct_m_0"] = corr_m
+ logging_output[f"count_m_0"] = count_m
+
+ corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
+ logging_output[f"correct_u_0"] = corr_u
+ logging_output[f"count_u_0"] = count_u
+
+ return loss, sample_size, logging_output
diff --git a/egs/librispeech/SSL/hubert/joiner.py b/egs/librispeech/SSL/hubert/joiner.py
new file mode 120000
index 000000000..aa3362cda
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/joiner.py
@@ -0,0 +1 @@
+../../ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py
new file mode 100644
index 000000000..46a968b69
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/model.py
@@ -0,0 +1,344 @@
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Zengwei Yao,
+# Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from scaling import ScaledLinear
+
+from icefall.utils import add_sos
+
+
+class AsrModel(nn.Module):
+ def __init__(
+ self,
+ encoder,
+ decoder: Optional[nn.Module] = None,
+ joiner: Optional[nn.Module] = None,
+ encoder_dim: int = 768,
+ decoder_dim: int = 512,
+ vocab_size: int = 500,
+ use_transducer: bool = True,
+ use_ctc: bool = False,
+ ):
+ """A joint CTC & Transducer ASR model.
+
+ - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
+ - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
+ - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
+
+ Args:
+ encoder:
+ It is the transcription network in the paper. It accepts
+ raw waveforms `x` of shape (N, T).
+ It returns two tensors: `encoder_out` of shape (N, T, encoder_dim) and
+ `encoder_out_lens` of shape (N,).
+ decoder:
+ It is the prediction network in the paper. Its input shape
+ is (N, U) and its output shape is (N, U, decoder_dim).
+ It should contain one attribute: `blank_id`.
+ It is used when use_transducer is True.
+ joiner:
+ It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+ Its output shape is (N, T, U, vocab_size). Note that its output contains
+ unnormalized probs, i.e., not processed by log-softmax.
+ It is used when use_transducer is True.
+ use_transducer:
+ Whether to use the transducer head. Default: True.
+ use_ctc:
+ Whether to use the CTC head. Default: False.
+ """
+ super().__init__()
+
+ assert (
+ use_transducer or use_ctc
+ ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
+
+ self.encoder = encoder
+
+ self.use_transducer = use_transducer
+ if use_transducer:
+ # Modules for Transducer head
+ assert decoder is not None
+ assert hasattr(decoder, "blank_id")
+ assert joiner is not None
+
+ self.decoder = decoder
+ self.joiner = joiner
+
+ self.simple_am_proj = ScaledLinear(
+ encoder_dim, vocab_size, initial_scale=0.25
+ )
+ self.simple_lm_proj = ScaledLinear(
+ decoder_dim, vocab_size, initial_scale=0.25
+ )
+ else:
+ assert decoder is None
+ assert joiner is None
+
+ self.use_ctc = use_ctc
+ if use_ctc:
+ # Modules for CTC head
+ self.ctc_output = nn.Sequential(
+ nn.Dropout(p=0.1),
+ nn.Linear(encoder_dim, vocab_size),
+ nn.LogSoftmax(dim=-1),
+ )
+
+ def forward_encoder(
+ self,
+ x: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute encoder outputs.
+ Args:
+ x:
+ A 2-D tensor of shape (N, T).
+
+ Returns:
+ encoder_out:
+ Encoder output, of shape (N, T, C).
+ encoder_out_lens:
+ Encoder output lengths, of shape (N,).
+ """
+ if padding_mask is None:
+ padding_mask = torch.zeros_like(x, dtype=torch.bool)
+
+ encoder_out, padding_mask = self.encoder.extract_features(
+ source=x,
+ padding_mask=padding_mask,
+ mask=self.encoder.training,
+ )
+ encoder_out_lens = torch.sum(~padding_mask, dim=1)
+ assert torch.all(encoder_out_lens > 0), encoder_out_lens
+
+ return encoder_out, encoder_out_lens
+
+ def forward_ctc(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ targets: torch.Tensor,
+ target_lengths: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute CTC loss.
+ Args:
+ encoder_out:
+ Encoder output, of shape (N, T, C).
+ encoder_out_lens:
+ Encoder output lengths, of shape (N,).
+ targets:
+ Target Tensor of shape (sum(target_lengths)). The targets are assumed
+ to be un-padded and concatenated within 1 dimension.
+ """
+ # Compute CTC log-prob
+ ctc_output = self.ctc_output(encoder_out) # (N, T, C)
+
+ ctc_loss = torch.nn.functional.ctc_loss(
+ log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
+ targets=targets,
+ input_lengths=encoder_out_lens,
+ target_lengths=target_lengths,
+ reduction="sum",
+ )
+ return ctc_loss
+
+ def forward_transducer(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ y: k2.RaggedTensor,
+ y_lens: torch.Tensor,
+ prune_range: int = 5,
+ am_scale: float = 0.0,
+ lm_scale: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute Transducer loss.
+ Args:
+ encoder_out:
+ Encoder output, of shape (N, T, C).
+ encoder_out_lens:
+ Encoder output lengths, of shape (N,).
+ y:
+ A ragged tensor with 2 axes [utt][label]. It contains labels of each
+ utterance.
+ prune_range:
+ The prune range for rnnt loss; it is the number of symbols (context)
+ we consider for each frame when computing the loss.
+ am_scale:
+ The scale to smooth the loss with am (output of encoder network)
+ part
+ lm_scale:
+ The scale to smooth the loss with lm (output of predictor network)
+ part
+ """
+ # Now for the decoder, i.e., the prediction network
+ blank_id = self.decoder.blank_id
+ sos_y = add_sos(y, sos_id=blank_id)
+
+ # sos_y_padded: [B, S + 1], start with SOS.
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+ # decoder_out: [B, S + 1, decoder_dim]
+ decoder_out = self.decoder(sos_y_padded)
+
+ # Note: y does not start with SOS
+ # y_padded : [B, S]
+ y_padded = y.pad(mode="constant", padding_value=0)
+
+ y_padded = y_padded.to(torch.int64)
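+ # Each boundary row is (begin_symbol, begin_frame, end_symbol, end_frame);
+ # only the end columns are non-zero since all sequences start at 0.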
+ boundary = torch.zeros(
+ (encoder_out.size(0), 4),
+ dtype=torch.int64,
+ device=encoder_out.device,
+ )
+ boundary[:, 2] = y_lens
+ boundary[:, 3] = encoder_out_lens
+
+ lm = self.simple_lm_proj(decoder_out)
+ am = self.simple_am_proj(encoder_out)
+
+ # if self.training and random.random() < 0.25:
+ # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
+ # if self.training and random.random() < 0.25:
+ # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+ lm=lm.float(),
+ am=am.float(),
+ symbols=y_padded,
+ termination_symbol=blank_id,
+ lm_only_scale=lm_scale,
+ am_only_scale=am_scale,
+ boundary=boundary,
+ reduction="sum",
+ return_grad=True,
+ )
+
+ # ranges : [B, T, prune_range]
+ ranges = k2.get_rnnt_prune_ranges(
+ px_grad=px_grad,
+ py_grad=py_grad,
+ boundary=boundary,
+ s_range=prune_range,
+ )
+
+ # am_pruned : [B, T, prune_range, encoder_dim]
+ # lm_pruned : [B, T, prune_range, decoder_dim]
+ am_pruned, lm_pruned = k2.do_rnnt_pruning(
+ am=self.joiner.encoder_proj(encoder_out),
+ lm=self.joiner.decoder_proj(decoder_out),
+ ranges=ranges,
+ )
+
+ # logits : [B, T, prune_range, vocab_size]
+
+ # project_input=False since we applied the decoder's input projections
+ # prior to do_rnnt_pruning (this is an optimization for speed).
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ pruned_loss = k2.rnnt_loss_pruned(
+ logits=logits.float(),
+ symbols=y_padded,
+ ranges=ranges,
+ termination_symbol=blank_id,
+ boundary=boundary,
+ reduction="sum",
+ )
+
+ return simple_loss, pruned_loss
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ y: k2.RaggedTensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ prune_range: int = 5,
+ am_scale: float = 0.0,
+ lm_scale: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x:
+ A 2-D tensor of shape (N, T).
+ y:
+ A ragged tensor with 2 axes [utt][label]. It contains labels of each
+ utterance.
+ prune_range:
+ The prune range for rnnt loss; it is the number of symbols (context)
+ we consider for each frame when computing the loss.
+ am_scale:
+ The scale to smooth the loss with am (output of encoder network)
+ part
+ lm_scale:
+ The scale to smooth the loss with lm (output of predictor network)
+ part
+ Returns:
+ Return the transducer losses, the CTC loss, and the encoder output
+ lengths, in the form of (simple_loss, pruned_loss, ctc_loss, encoder_out_lens)
+
+ Note:
+ Regarding am_scale & lm_scale, it will make the loss-function one of
+ the form:
+ lm_scale * lm_probs + am_scale * am_probs +
+ (1-lm_scale-am_scale) * combined_probs
+ """
+ assert x.ndim == 2, x.shape
+ assert y.num_axes == 2, y.num_axes
+
+ assert x.size(0) == y.dim0, (x.shape, y.dim0)
+
+ # Compute encoder outputs
+ encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)
+
+ row_splits = y.shape.row_splits(1)
+ y_lens = row_splits[1:] - row_splits[:-1]
+
+ if self.use_transducer:
+ # Compute transducer loss
+ simple_loss, pruned_loss = self.forward_transducer(
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ y=y.to(x.device),
+ y_lens=y_lens,
+ prune_range=prune_range,
+ am_scale=am_scale,
+ lm_scale=lm_scale,
+ )
+ else:
+ simple_loss = torch.empty(0)
+ pruned_loss = torch.empty(0)
+
+ if self.use_ctc:
+ # Compute CTC loss
+ targets = y.values
+ ctc_loss = self.forward_ctc(
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ targets=targets,
+ target_lengths=y_lens,
+ )
+ else:
+ ctc_loss = torch.empty(0)
+
+ return simple_loss, pruned_loss, ctc_loss, encoder_out_lens
diff --git a/egs/librispeech/SSL/hubert/optim.py b/egs/librispeech/SSL/hubert/optim.py
new file mode 120000
index 000000000..56b827b8a
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/optim.py
@@ -0,0 +1 @@
+../../ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py
new file mode 100644
index 000000000..d9bda8857
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/pretrain.py
@@ -0,0 +1,1082 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For hubert model pretraining:
+./hubert/pretrain.py \
+ --world-size 8 \
+ --num-epochs 400 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir hubert/exp \
+ --full-libri 1 \
+ --max-duration 87.5 \
+ --accum-grad 4
+"""
+
+
+import argparse
+import copy
+import logging
+import sys
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from hubert import HubertModel, add_hubert_arguments
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from optim import Eden, ScaledAdam
+from ssl_datamodule import LibriSpeechDataModule
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.functional import pad
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
+ return (
+ params.batch_idx_train
+ * params.accum_grad
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=400,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="hubert/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=10.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--warmup-batches",
+ type=float,
+ default=5000,
+ help="Eden warmup steps",
+ )
+
+ parser.add_argument(
+ "--warmup-start",
+ type=float,
+ default=0,
+ help="Eden warmup start learning rate",
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=80,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--sanity-check",
+ type=str2bool,
+ default=False,
+ help="Check if any of the batches in epoch 1 would cause OOM.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=100000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--accum-grad",
+ type=int,
+ default=4,
+ help="""update gradient when batch_idx_train % accum_grad == 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--max-keep-size",
+ type=int,
+ default=sys.maxsize,
+ help="exclude sample longer than this.",
+ )
+
+ parser.add_argument(
+ "--min-keep-size",
+ type=float,
+ default=32000,
+ help="exclude sample longer less than this.",
+ )
+
+ parser.add_argument(
+ "--max-sample-size",
+ type=float,
+ default=250000,
+ help="max sample size to crop to for batching.",
+ )
+
+ add_hubert_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of parameter updates applied to the model so far
+ across epochs.
+
+ - sub_batch_idx_train: It contains the number of batches trained so far
+ across epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "sub_batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ model = HubertModel(params)
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `dataset.HubertDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+ kmeans = batch["kmeans"].to(device)
+
+ with torch.set_grad_enabled(is_training):
+ loss, num_masked_tokens, logging_output = model(
+ source=audio, target_list=[kmeans], padding_mask=padding_mask
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = num_masked_tokens
+ for item in logging_output:
+ info[item] = logging_output[item]
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for sub_batch_idx, batch in enumerate(train_dl):
+ params.sub_batch_idx_train += 1
+ batch_idx = sub_batch_idx // params.accum_grad
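+ # sub_batch_idx counts dataloader batches; batch_idx counts optimizer
+ # steps, i.e. groups of accum_grad sub-batches.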
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ batch_size = batch["kmeans"].shape[0]
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss / params.accum_grad).backward()
+
+ if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
+ params.batch_idx_train += 1
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ else:
+ continue
+
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ if batch_idx % params.accum_grad != params.accum_grad - 1:
+ optimizer.zero_grad()
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
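+ # model_avg keeps a float64 running average of the parameters; it is
+ # updated every --average-period batches on rank 0.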
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(
+ optimizer,
+ params.lr_batches,
+ params.lr_epochs,
+ params.warmup_batches,
+ params.warmup_start,
+ )
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechDataModule(args)
+
+ train_cuts = (
+ librispeech.train_all_shuf_cuts()
+ if params.full_libri
+ else librispeech.train_clean_100_cuts()
+ )
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances whose duration (in samples) lies within
+ # [min_keep_size, max_keep_size].
+ #
+ # You can use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset and select
+ # suitable thresholds.
+ if (
+ c.duration < params.min_keep_size / params.sample_rate
+ or c.duration > params.max_keep_size / params.sample_rate
+ ):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts,
+ max_sample_size=params.max_sample_size,
+ sample_rate=params.sample_rate,
+ label_rate=params.label_rate,
+ random_crop=params.random_crop,
+ pad_audio=False,
+ num_classes=params.num_classes,
+ do_normalize=params.do_normalize,
+ sampler_state_dict=sampler_state_dict,
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ # valid_cuts += librispeech.dev_other_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
+
+ valid_dl = librispeech.valid_dataloaders(
+ valid_cuts,
+ max_sample_size=params.max_sample_size,
+ sample_rate=params.sample_rate,
+ label_rate=params.label_rate,
+ random_crop=params.random_crop,
+ pad_audio=False,
+ num_classes=params.num_classes,
+ do_normalize=params.do_normalize,
+ )
+
+ if params.sanity_check and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `dataset.HubertDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ audio = batch["audio"]
+ logging.info(f"audio shape: {audio.shape}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py
new file mode 100644
index 000000000..24c0d4d3a
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/pretrain_ce.py
@@ -0,0 +1,1082 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For hubert model pretraining:
+./hubert/pretrain_ce.py \
+ --world-size 8 \
+ --num-epochs 400 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir hubert/exp \
+ --full-libri 1 \
+ --max-duration 87.5 \
+ --accum-grad 4
+"""
+
+
+import argparse
+import copy
+import logging
+import sys
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from hubert_ce import HubertModel, add_hubert_arguments
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from optim import Eden, ScaledAdam
+from ssl_datamodule import LibriSpeechDataModule
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.functional import pad
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
+ return (
+ params.batch_idx_train
+ * params.accum_grad
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=400,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="hubert/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=10.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--warmup-batches",
+ type=float,
+ default=5000,
+ help="Eden warmup steps",
+ )
+
+ parser.add_argument(
+ "--warmup-start",
+ type=float,
+ default=0,
+ help="Eden warmup start learning rate",
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=80,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--sanity-check",
+ type=str2bool,
+ default=False,
+ help="Check if any of the batches in epoch 1 would cause OOM.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=100000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--accum-grad",
+ type=int,
+ default=4,
+ help="""update gradient when batch_idx_train % accum_grad == 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--max-keep-size",
+ type=int,
+ default=sys.maxsize,
+ help="exclude sample longer than this.",
+ )
+
+ parser.add_argument(
+ "--min-keep-size",
+ type=float,
+ default=32000,
+ help="exclude sample longer less than this.",
+ )
+
+ parser.add_argument(
+ "--max-sample-size",
+ type=float,
+ default=250000,
+ help="max sample size to crop to for batching.",
+ )
+
+ add_hubert_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+      contains the number of optimizer updates applied to the model so far
+      across epochs.
+
+    - sub_batch_idx_train: It contains the number of sub-batches trained so
+      far across epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+    - valid_interval: Run validation if batch_idx % valid_interval is 0
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "sub_batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ model = HubertModel(params)
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.exp_dir/epoch-{params.start_epoch - 1}.pt`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+        The model for training. It is an instance of HubertModel in our case.
+ batch:
+ A batch of data. See `dataset.HubertDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+ kmeans = batch["kmeans"].to(device)
+
+ with torch.set_grad_enabled(is_training):
+ loss, num_masked_tokens, logging_output = model(
+ source=audio, target_list=[kmeans], padding_mask=padding_mask
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
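+        # For HuBERT pre-training, "frames" counts the masked tokens the loss
+        # is computed over; it is later used to normalize the logged loss.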
+ info["frames"] = num_masked_tokens
+ for item in logging_output:
+ info[item] = logging_output[item]
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+        The learning rate scheduler; step_batch() is called after every optimizer update.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for sub_batch_idx, batch in enumerate(train_dl):
+ params.sub_batch_idx_train += 1
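+        # `accum_grad` consecutive sub-batches form one effective batch;
+        # `batch_idx` therefore counts optimizer-level batches.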
+ batch_idx = sub_batch_idx // params.accum_grad
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ batch_size = batch["kmeans"].shape[0]
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss / params.accum_grad).backward()
+
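+                # Step the optimizer only on the last sub-batch of each
+                # accumulation window; other sub-batches just accumulate
+                # gradients via the scaled backward() above.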
+ if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
+ params.batch_idx_train += 1
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ else:
+ continue
+
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
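+    # Discard any leftover partially accumulated gradients at the end of the
+    # epoch so they do not leak into the next one.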
+ if batch_idx % params.accum_grad != params.accum_grad - 1:
+ optimizer.zero_grad()
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(
+ optimizer,
+ params.lr_batches,
+ params.lr_epochs,
+ params.warmup_batches,
+ params.warmup_start,
+ )
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechDataModule(args)
+
+ train_cuts = (
+ librispeech.train_all_shuf_cuts()
+ if params.full_libri
+ else librispeech.train_clean_100_cuts()
+ )
+
+ def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances whose duration (in seconds) lies within
+        # [min_keep_size / sample_rate, max_keep_size / sample_rate].
+        #
+        # You can use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset and select
+        # suitable thresholds.
+ if (
+ c.duration < params.min_keep_size / params.sample_rate
+ or c.duration > params.max_keep_size / params.sample_rate
+ ):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts,
+ max_sample_size=params.max_sample_size,
+ sample_rate=params.sample_rate,
+ label_rate=params.label_rate,
+ random_crop=params.random_crop,
+ pad_audio=False,
+ num_classes=params.num_classes,
+ do_normalize=params.do_normalize,
+ sampler_state_dict=sampler_state_dict,
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ # valid_cuts += librispeech.dev_other_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
+
+ valid_dl = librispeech.valid_dataloaders(
+ valid_cuts,
+ max_sample_size=params.max_sample_size,
+ sample_rate=params.sample_rate,
+ label_rate=params.label_rate,
+ random_crop=params.random_crop,
+ pad_audio=False,
+ num_classes=params.num_classes,
+ do_normalize=params.do_normalize,
+ )
+
+ if params.sanity_check and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `dataset.HubertDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+    """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ audio = batch["audio"]
+ logging.info(f"audio shape: {audio.shape}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
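+# Example invocation (the script name, paths, and option values below are
+# illustrative only):
+#
+#   python3 ./hubert/pretrain.py \
+#     --world-size 4 \
+#     --num-epochs 400 \
+#     --exp-dir hubert/exp \
+#     --full-libri 1 \
+#     --max-duration 200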
+def main():
+ parser = get_parser()
+ LibriSpeechDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/SSL/hubert/scaling.py b/egs/librispeech/SSL/hubert/scaling.py
new file mode 120000
index 000000000..e30bd99de
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/scaling.py
@@ -0,0 +1 @@
+../../ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/hubert/ssl_datamodule.py b/egs/librispeech/SSL/hubert/ssl_datamodule.py
new file mode 100644
index 000000000..ac1a0997d
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/ssl_datamodule.py
@@ -0,0 +1,341 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from dataset import HubertDataset
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class LibriSpeechDataModule:
+ """
+ DataModule for SSL experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in SSL
+ experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers.
+
+ This class should be derived for specific corpora used in SSL tasks.
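+
+    A minimal usage sketch (argument values are illustrative):
+
+        parser = argparse.ArgumentParser()
+        LibriSpeechDataModule.add_arguments(parser)
+        args = parser.parse_args()
+
+        librispeech = LibriSpeechDataModule(args)
+        train_cuts = librispeech.train_clean_100_cuts()
+        train_dl = librispeech.train_dataloaders(train_cuts)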
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="SSL data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies.",
+ )
+ group.add_argument(
+ "--full-libri",
+ type=str2bool,
+ default=True,
+ help="When enabled use 960h LibriSpeech. " "Otherwise, use 100h subset.",
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/kmeans"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=float,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+ group.add_argument(
+ "--do-normalize",
+ type=str2bool,
+ default=True,
+ help="whether to normalize the data",
+ )
+ group.add_argument(
+ "--random-crop",
+ type=str2bool,
+ default=True,
+ help="always crop from the beginning if false",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ max_sample_size: Optional[int] = None,
+ sample_rate: float = 16000,
+ label_rate: float = 50,
+ random_crop: bool = True,
+ pad_audio: bool = False,
+ num_classes: list = [504],
+ do_normalize: bool = True,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to create train dataset")
+ train = HubertDataset(
+ max_sample_size=max_sample_size,
+ sample_rate=sample_rate,
+ label_rate=label_rate,
+ random_crop=random_crop,
+ pad_audio=pad_audio,
+ num_classes=num_classes,
+ do_normalize=do_normalize,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(
+ self,
+ cuts_valid: CutSet,
+ max_sample_size: Optional[int] = None,
+ sample_rate: float = 16000,
+ label_rate: float = 50,
+ random_crop: bool = True,
+ pad_audio: bool = False,
+ num_classes: list = [504],
+ do_normalize: bool = True,
+ ) -> DataLoader:
+ logging.info("About to create dev dataset")
+ validate = HubertDataset(
+ max_sample_size=max_sample_size,
+ sample_rate=sample_rate,
+ label_rate=label_rate,
+ random_crop=random_crop,
+ pad_audio=pad_audio,
+ num_classes=num_classes,
+ do_normalize=do_normalize,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(
+ self,
+ cuts: CutSet,
+ sample_rate: float = 16000,
+ label_rate: float = 50,
+ random_crop: bool = True,
+ pad_audio: bool = False,
+ num_classes: list = [504],
+ do_normalize: bool = True,
+ ) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = HubertDataset(
+ sample_rate=sample_rate,
+ label_rate=label_rate,
+ random_crop=random_crop,
+ pad_audio=pad_audio,
+ num_classes=num_classes,
+ do_normalize=do_normalize,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_clean_100_cuts(self) -> CutSet:
+ logging.info("About to get train-clean-100 cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_clean_360_cuts(self) -> CutSet:
+ logging.info("About to get train-clean-360 cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_other_500_cuts(self) -> CutSet:
+ logging.info("About to get train-other-500 cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_all_shuf_cuts(self) -> CutSet:
+ logging.info(
+ "About to get the shuffled train-clean-100, \
+ train-clean-360 and train-other-500 cuts"
+ )
+ train_clean_100_cuts = self.train_clean_100_cuts()
+ train_clean_360_cuts = self.train_clean_360_cuts()
+ train_other_500_cuts = self.train_other_500_cuts()
+ return CutSet.mux(
+ train_clean_100_cuts,
+ train_clean_360_cuts,
+ train_other_500_cuts,
+ weights=[
+ 28539, # len(train_clean_100_cuts)
+ 104014, # len(train_clean_360_cuts)
+ 148688, # len(train_other_500_cuts)
+ ],
+ )
+
+ @lru_cache()
+ def dev_clean_cuts(self) -> CutSet:
+ logging.info("About to get dev-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_other_cuts(self) -> CutSet:
+ logging.info("About to get dev-other cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_clean_cuts(self) -> CutSet:
+ logging.info("About to get test-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_other_cuts(self) -> CutSet:
+ logging.info("About to get test-other cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
+ )
diff --git a/egs/librispeech/SSL/hubert/utils.py b/egs/librispeech/SSL/hubert/utils.py
new file mode 100644
index 000000000..de980ba62
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/utils.py
@@ -0,0 +1,338 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import math
+from typing import Callable, List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def relu_squared(x: torch.Tensor):
+ return F.relu(x).pow(2)
+
+
+def gelu_accurate(x):
+ if not hasattr(gelu_accurate, "_a"):
+ gelu_accurate._a = math.sqrt(2 / math.pi)
+ return (
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+ )
+
+
+def is_xla_tensor(tensor):
+ return torch.is_tensor(tensor) and tensor.device.type == "xla"
+
+
+def index_put(tensor, indices, value):
+ if is_xla_tensor(tensor):
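+        # XLA tensors do not handle in-place boolean-mask assignment well, so
+        # emulate it: expand `indices` to the tensor's shape and blend the old
+        # values with `value` via multiply/add, keeping the op traceable.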
+ for _ in range(indices.dim(), tensor.dim()):
+ indices = indices.unsqueeze(-1)
+ if indices.size(-1) < tensor.size(-1):
+ indices = indices.expand_as(tensor)
+ tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
+ else:
+ tensor[indices] = value
+ return tensor
+
+
+def pad_to_multiple(x, multiple, dim=-1, value=0):
+ # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
+ if x is None:
+ return None, 0
+ tsz = x.size(dim)
+ m = tsz / multiple
+ remainder = math.ceil(m) * multiple - tsz
+ if m.is_integer():
+ return x, 0
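+    # F.pad expects pad widths for the trailing dimensions first; the zero
+    # pairs in `pad_offset` skip every dimension after `dim`, and `remainder`
+    # is then appended on the right side of `dim`.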
+ pad_offset = (0,) * (-1 - dim) * 2
+
+ return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.gelu(x.float()).type_as(x)
+
+
+def get_activation_fn(activation: str) -> Callable:
+ """Returns the activation function corresponding to `activation`"""
+ if activation == "relu":
+ return F.relu
+ elif activation == "relu_squared":
+ return relu_squared
+ elif activation == "gelu":
+ return gelu
+ elif activation == "gelu_fast":
+ return gelu_accurate
+ elif activation == "gelu_accurate":
+ return gelu_accurate
+ elif activation == "tanh":
+ return torch.tanh
+ elif activation == "linear":
+ return lambda x: x
+ elif activation == "swish":
+ return torch.nn.SiLU
+ else:
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+class SamePad(nn.Module):
+ def __init__(self, kernel_size, causal=False):
+ super().__init__()
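+        # A convolution with padding = kernel_size // 2 yields one extra output
+        # frame when the kernel size is even (or kernel_size - 1 extra frames in
+        # the causal case); `remove` trims those trailing frames.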
+ if causal:
+ self.remove = kernel_size - 1
+ else:
+ self.remove = 1 if kernel_size % 2 == 0 else 0
+
+ def forward(self, x):
+ if self.remove > 0:
+ x = x[:, :, : -self.remove]
+ return x
+
+
+class SamePad2d(nn.Module):
+ def __init__(self, kernel_size):
+ super().__init__()
+ self.remove = 1 if kernel_size % 2 == 0 else 0
+
+ def forward(self, x):
+ assert len(x.size()) == 4
+ if self.remove > 0:
+ x = x[:, :, : -self.remove, : -self.remove]
+ return x
+
+
+class TransposeLast(nn.Module):
+ def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
+ super().__init__()
+ self.deconstruct_idx = deconstruct_idx
+ self.tranpose_dim = tranpose_dim
+
+ def forward(self, x):
+ if self.deconstruct_idx is not None:
+ x = x[self.deconstruct_idx]
+ return x.transpose(self.tranpose_dim, -1)
+
+
+try:
+ from apex.normalization import FusedLayerNorm as _FusedLayerNorm
+
+ has_fused_layernorm = True
+
+ class FusedLayerNorm(_FusedLayerNorm):
+ @torch.jit.unused
+ def forward(self, x):
+ if not x.is_cuda:
+ return super().forward(x)
+ else:
+ with torch.cuda.device(x.device):
+ return super().forward(x)
+
+except ImportError:
+ has_fused_layernorm = False
+
+
+def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
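+    # Prefer apex's FusedLayerNorm on CUDA when it is available; fall back to
+    # torch.nn.LayerNorm when exporting (scripting/tracing) or apex is absent.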
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ export = True
+ if not export and torch.cuda.is_available() and has_fused_layernorm:
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+class Fp32LayerNorm(nn.LayerNorm):
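+    """LayerNorm computed in float32 (useful under mixed precision); the output
+    is cast back to the input dtype."""
+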
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.layer_norm(
+ input.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
+
+
+class Fp32GroupNorm(nn.GroupNorm):
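+    """GroupNorm computed in float32 (useful under mixed precision); the output
+    is cast back to the input dtype."""
+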
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.group_norm(
+ input.float(),
+ self.num_groups,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
+
+
+def softmax(x, dim: int, onnx_trace: bool = False):
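+    # Compute softmax in float32 for numerical stability. When tracing for
+    # ONNX export, cast the input explicitly instead of passing dtype=.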
+ if onnx_trace:
+ return F.softmax(x.float(), dim=dim)
+ else:
+ return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+def quant_noise(module, p, block_size):
+ """
+ Wraps modules and applies quantization noise to the weights for
+ subsequent quantization with Iterative Product Quantization as
+ described in "Training with Quantization Noise for Extreme Model Compression"
+
+ Args:
+ - module: nn.Module
+ - p: amount of Quantization Noise
+ - block_size: size of the blocks for subsequent quantization with iPQ
+
+ Remarks:
+ - Module weights must have the right sizes wrt the block size
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
+ - For more detail on how to quantize by blocks with convolutional weights,
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
+      - We implement the simplest form of noise here as stated in the paper,
+      which consists of randomly dropping blocks
+ """
+
+ # if no quantization noise, don't register hook
+ if p <= 0:
+ return module
+
+ # supported modules
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
+
+ # test whether module.weight has the right sizes wrt block_size
+ is_conv = module.weight.ndim == 4
+
+ # 2D matrix
+ if not is_conv:
+ assert (
+ module.weight.size(1) % block_size == 0
+ ), "Input features must be a multiple of block sizes"
+
+ # 4D matrix
+ else:
+ # 1x1 convolutions
+ if module.kernel_size == (1, 1):
+ assert (
+ module.in_channels % block_size == 0
+ ), "Input channels must be a multiple of block sizes"
+ # regular convolutions
+ else:
+ k = module.kernel_size[0] * module.kernel_size[1]
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
+
+ def _forward_pre_hook(mod, input):
+ # no noise for evaluation
+ if mod.training:
+ if not is_conv:
+ # gather weight and sizes
+ weight = mod.weight
+ in_features = weight.size(1)
+ out_features = weight.size(0)
+
+ # split weight matrix into blocks and randomly drop selected blocks
+ mask = torch.zeros(
+ in_features // block_size * out_features,
+ device=weight.device,
+ )
+ mask.bernoulli_(p)
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+
+ else:
+ # gather weight and sizes
+ weight = mod.weight
+ in_channels = mod.in_channels
+ out_channels = mod.out_channels
+
+ # split weight matrix into blocks and randomly drop selected blocks
+ if mod.kernel_size == (1, 1):
+ mask = torch.zeros(
+ int(in_channels // block_size * out_channels),
+ device=weight.device,
+ )
+ mask.bernoulli_(p)
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+ else:
+ mask = torch.zeros(
+ weight.size(0), weight.size(1), device=weight.device
+ )
+ mask.bernoulli_(p)
+ mask = (
+ mask.unsqueeze(2)
+ .unsqueeze(3)
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+ )
+
+ # scale weights and apply mask
+ mask = mask.to(
+ torch.bool
+ ) # x.bool() is not currently supported in TorchScript
+ s = 1 / (1 - p)
+ mod.weight.data = s * weight.masked_fill(mask, 0)
+
+ module.register_forward_pre_hook(_forward_pre_hook)
+ return module
+
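+# A hypothetical usage sketch: wrap a linear layer so that 10% of its weight
+# blocks (block size 8) are dropped with quantization noise during training.
+#
+#   layer = quant_noise(nn.Linear(512, 512), p=0.1, block_size=8)
+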
+
+class FairseqDropout(nn.Module):
+ def __init__(self, p, module_name=None):
+ super().__init__()
+ self.p = p
+ self.module_name = module_name
+ self.apply_during_inference = False
+
+ def forward(self, x, inplace: bool = False):
+ if self.p > 0 and (self.training or self.apply_during_inference):
+ return F.dropout(x, p=self.p, training=True, inplace=inplace)
+ else:
+ return x
+
+ def make_generation_fast_(
+ self,
+ name: str,
+ retain_dropout: bool = False,
+ retain_dropout_modules: Optional[List[str]] = None,
+ **kwargs
+ ):
+ if retain_dropout:
+ if retain_dropout_modules is not None and self.module_name is None:
+ pass
+ elif (
+ retain_dropout_modules is None # if None, apply to all modules
+ or self.module_name in retain_dropout_modules
+ ):
+ self.apply_during_inference = True
+
+
+class GradMultiply(torch.autograd.Function):
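+    """Identity in the forward pass; scales the incoming gradient by ``scale``
+    in the backward pass. Commonly used to down-weight the gradient flowing
+    into a sub-module (e.g. the convolutional feature extractor)."""
+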
+ @staticmethod
+ def forward(ctx, x, scale):
+ ctx.scale = scale
+ res = x.new(x)
+ return res
+
+ @staticmethod
+ def backward(ctx, grad):
+ return grad * ctx.scale, None
diff --git a/egs/librispeech/SSL/hubert/wav2vec2_module.py b/egs/librispeech/SSL/hubert/wav2vec2_module.py
new file mode 100644
index 000000000..4c2e1ce98
--- /dev/null
+++ b/egs/librispeech/SSL/hubert/wav2vec2_module.py
@@ -0,0 +1,593 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import math
+from typing import List, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from attention_module import MultiheadAttention, init_bert_params
+from utils import (
+ Fp32GroupNorm,
+ Fp32LayerNorm,
+ LayerNorm,
+ SamePad,
+ TransposeLast,
+ get_activation_fn,
+ index_put,
+ pad_to_multiple,
+)
+
+
+class ConvFeatureExtractionModel(nn.Module):
+ def __init__(
+ self,
+ conv_layers: List[Tuple[int, int, int]],
+ dropout: float = 0.0,
+ mode: str = "default",
+ conv_bias: bool = False,
+ ):
+ super().__init__()
+
+ assert mode in {"default", "layer_norm"}
+
+ def block(
+ n_in,
+ n_out,
+ k,
+ stride,
+ is_layer_norm=False,
+ is_group_norm=False,
+ conv_bias=False,
+ ):
+ def make_conv():
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
+ nn.init.kaiming_normal_(conv.weight)
+ return conv
+
+ assert (
+ is_layer_norm and is_group_norm
+ ) == False, "layer norm and group norm are exclusive"
+
+ if is_layer_norm:
+ return nn.Sequential(
+ make_conv(),
+ nn.Dropout(p=dropout),
+ nn.Sequential(
+ TransposeLast(),
+ Fp32LayerNorm(dim, elementwise_affine=True),
+ TransposeLast(),
+ ),
+ nn.GELU(),
+ )
+ elif is_group_norm:
+ return nn.Sequential(
+ make_conv(),
+ nn.Dropout(p=dropout),
+ Fp32GroupNorm(dim, dim, affine=True),
+ nn.GELU(),
+ )
+ else:
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
+
+ in_d = 1
+ self.conv_layers = nn.ModuleList()
+ for i, cl in enumerate(conv_layers):
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
+ (dim, k, stride) = cl
+
+ self.conv_layers.append(
+ block(
+ in_d,
+ dim,
+ k,
+ stride,
+ is_layer_norm=mode == "layer_norm",
+ is_group_norm=mode == "default" and i == 0,
+ conv_bias=conv_bias,
+ )
+ )
+ in_d = dim
+
+ def forward(self, x):
+ # BxT -> BxCxT
+ x = x.unsqueeze(1)
+
+ for conv in self.conv_layers:
+ x = conv(x)
+
+ return x
+
+
+def make_conv_pos(e, k, g, is_batch_norm=False):
+ pos_conv = nn.Conv1d(
+ e,
+ e,
+ kernel_size=k,
+ padding=k // 2,
+ groups=g,
+ )
+ dropout = 0
+ std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
+ nn.init.normal_(pos_conv.weight, mean=0, std=std)
+ nn.init.constant_(pos_conv.bias, 0)
+
+ if not is_batch_norm:
+ pos_conv = nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2)
+ pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
+ else:
+ batch_norm = nn.BatchNorm1d(e)
+ pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU())
+
+ return pos_conv
+
+
+class TransformerEncoder(nn.Module):
+ def build_encoder_layer(self, args, **kwargs):
+ if args.layer_type == "transformer":
+ layer = TransformerSentenceEncoderLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ )
+ elif args.layer_type == "trf_adp":
+ use_adp = False
+ if args.adp_trf_idx == "all":
+ use_adp = True
+ else:
+ adp_trf_idx = list(
+ range(*[int(g) for g in args.adp_trf_idx.split(":")])
+ )
+ if kwargs.get("layer_idx", None) in adp_trf_idx:
+ use_adp = True
+ if use_adp:
+ layer = TransformerSentenceEncoderWithAdapterLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ adapter_num=args.adp_num,
+ adapter_dim=args.adp_dim,
+ adapter_act_fn=args.adp_act_fn,
+ )
+ else:
+ layer = TransformerSentenceEncoderLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ )
+
+ # layer = fsdp_wrap(layer)
+ # if args.checkpoint_activations:
+ # layer = checkpoint_wrapper(layer)
+ return layer
+
+ def __init__(self, args):
+ super().__init__()
+
+ self.dropout = args.dropout
+ self.embedding_dim = args.encoder_embed_dim
+ self.required_seq_len_multiple = args.required_seq_len_multiple
+
+ pos_conv_depth = getattr(args, "pos_conv_depth", 1)
+ if pos_conv_depth > 1:
+ num_layers = args.pos_conv_depth
+ k = max(3, args.conv_pos // num_layers)
+
+ def make_conv_block(e, k, g, l):
+ return nn.Sequential(
+ *[
+ nn.Sequential(
+ nn.Conv1d(
+ e,
+ e,
+ kernel_size=k,
+ padding=k // 2,
+ groups=g,
+ ),
+ SamePad(k),
+ TransposeLast(),
+ LayerNorm(e, elementwise_affine=False),
+ TransposeLast(),
+ nn.GELU(),
+ )
+ for _ in range(l)
+ ]
+ )
+
+ self.pos_conv = make_conv_block(
+ self.embedding_dim, k, args.conv_pos_groups, num_layers
+ )
+
+ else:
+ self.pos_conv = make_conv_pos(
+ self.embedding_dim,
+ args.conv_pos,
+ args.conv_pos_groups,
+ is_batch_norm=args.conv_pos_batch_norm
+ if hasattr(args, "conv_pos_batch_norm")
+ else False,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ self.build_encoder_layer(args, layer_idx=ii)
+ for ii in range(args.encoder_layers)
+ ]
+ )
+ self.layer_norm_first = args.layer_norm_first
+ self.layer_norm = LayerNorm(self.embedding_dim)
+ self.layerdrop = args.encoder_layerdrop
+
+ self.apply(init_bert_params)
+
+ def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
+ x, layer_results = self.extract_features(
+ x, padding_mask, layer, corpus_key=corpus_key
+ )
+
+ if self.layer_norm_first and layer is None:
+ x = self.layer_norm(x)
+
+ return x, layer_results
+
+ def extract_features(
+ self,
+ x,
+ padding_mask=None,
+ tgt_layer=None,
+ min_layer=0,
+ corpus_key=None,
+ ):
+ if padding_mask is not None:
+ x = index_put(x, padding_mask, 0)
+
+ x_conv = self.pos_conv(x.transpose(1, 2))
+ x_conv = x_conv.transpose(1, 2)
+ x = x + x_conv
+
+ if not self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ # pad to the sequence length dimension
+ x, pad_length = pad_to_multiple(
+ x, self.required_seq_len_multiple, dim=-2, value=0
+ )
+ if pad_length > 0 and padding_mask is None:
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
+ padding_mask[:, -pad_length:] = True
+ else:
+ padding_mask, _ = pad_to_multiple(
+ padding_mask, self.required_seq_len_multiple, dim=-1, value=True
+ )
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ layer_results = []
+ r = None
+
+ for i, layer in enumerate(self.layers):
+ dropout_probability = np.random.random() if self.layerdrop > 0 else 1
+ if not self.training or (dropout_probability > self.layerdrop):
+ layer_check = layer
+ # if isinstance(layer, FullyShardedDataParallel):
+ # layer_check = layer.unwrapped_module
+ if (corpus_key is None) or (
+ not isinstance(
+ layer_check,
+ (TransformerSentenceEncoderWithAdapterLayer,),
+ )
+ ):
+ x, (z, lr) = layer(
+ x,
+ self_attn_padding_mask=padding_mask,
+ need_weights=False,
+ )
+ else:
+ x, (z, lr) = layer(
+ x,
+ self_attn_padding_mask=padding_mask,
+ need_weights=False,
+ corpus_key=corpus_key,
+ )
+ if i >= min_layer:
+ layer_results.append((x, z, lr))
+ if i == tgt_layer:
+ r = x
+ break
+
+ if r is not None:
+ x = r
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+        # undo padding
+ if pad_length > 0:
+ x = x[:, :-pad_length]
+
+ def undo_pad(a, b, c):
+ return (
+ a[:-pad_length],
+ b[:-pad_length] if b is not None else b,
+ c[:-pad_length],
+ )
+
+ layer_results = [undo_pad(*u) for u in layer_results]
+
+ return x, layer_results
+
+ def max_positions(self):
+ """Maximum output length supported by the encoder."""
+ return self.args.max_positions
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
+ return state_dict
+
+
+class TransformerSentenceEncoderLayer(nn.Module):
+ """
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
+ models.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: float = 768,
+ ffn_embedding_dim: float = 3072,
+ num_attention_heads: int = 8,
+ dropout: float = 0.1,
+ attention_dropout: float = 0.1,
+ activation_dropout: float = 0.1,
+ activation_fn: str = "relu",
+ layer_norm_first: bool = False,
+ ) -> None:
+ super().__init__()
+ # Initialize parameters
+ self.embedding_dim = embedding_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+
+ # Initialize blocks
+ self.activation_fn = get_activation_fn(activation_fn)
+ self.self_attn = MultiheadAttention(
+ self.embedding_dim,
+ num_attention_heads,
+ dropout=attention_dropout,
+ self_attention=True,
+ )
+
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(self.activation_dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.layer_norm_first = layer_norm_first
+
+ # layer norm associated with the self attention layer
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+
+ # layer norm associated with the position wise feed-forward NN
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ self_attn_mask: torch.Tensor = None,
+ self_attn_padding_mask: torch.Tensor = None,
+ need_weights: bool = False,
+ att_args=None,
+ ):
+ """
+        LayerNorm is applied either before or after the self-attention/ffn
+        modules, similar to the original Transformer implementation.
+ """
+ residual = x
+
+ if self.layer_norm_first:
+ x = self.self_attn_layer_norm(x)
+ x, attn = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ attn_mask=self_attn_mask,
+ need_weights=False,
+ )
+ x = self.dropout1(x)
+ x = residual + x
+
+ residual = x
+ x = self.final_layer_norm(x)
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+
+ layer_result = x
+
+ x = self.dropout3(x)
+ x = residual + x
+ else:
+ x, attn = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ need_weights=False,
+ )
+
+ x = self.dropout1(x)
+ x = residual + x
+
+ x = self.self_attn_layer_norm(x)
+
+ residual = x
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+
+ layer_result = x
+
+ x = self.dropout3(x)
+ x = residual + x
+ x = self.final_layer_norm(x)
+
+ return x, (attn, layer_result)
+
+
+class AdapterFast(nn.Module):
+ def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
+ """
+ Implements adapter modules directly with 3D tensor weight as parameters
+ and without using ModuleList orto speed up training throughput.
+ """
+ super().__init__()
+
+ self.adapter_num = adapter_num
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+ self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
+ self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
+ self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
+ self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
+
+ self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
+ self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
+ self.act_fn = nn.Identity()
+ if act_fn == "relu":
+ self.act_fn = nn.ReLU()
+ elif act_fn == "gelu":
+ self.act_fn = nn.GELU()
+ elif act_fn == "selu":
+ self.act_fn = nn.SELU()
+ else:
+ raise ValueError(f"unsupported {act_fn}")
+
+ self.input_dim = input_dim
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ for ii in range(self.adapter_num):
+ nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
+ nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ nn.init.uniform_(self.b_a[ii], -bound, bound)
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ nn.init.uniform_(self.b_b[ii], -bound, bound)
+
+ nn.init.ones_(self.ln_W)
+ nn.init.zeros_(self.ln_b)
+
+ def forward(self, x, adapter_id):
+ ii = adapter_id
+ h = x
+ h = F.layer_norm(h, (self.input_dim,), self.ln_W[ii], self.ln_b[ii])
+ h = F.linear(h, self.W_a[ii], self.b_a[ii])
+ h = self.act_fn(h)
+ h = F.linear(h, self.W_b[ii], self.b_b[ii])
+ outputs = h
+ return outputs
+
+ def extra_repr(self):
+ return "adapter={}, input_dim={}, hidden_dim={}".format(
+ self.adapter_num, self.input_dim, self.hidden_dim
+ )
+
+
+class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
+ """
+ Implements a Transformer Encoder Layer with adapters used in BERT/XLM style pre-trained
+    models. An adapter module is added alongside the vanilla Transformer modules.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: float = 768,
+ ffn_embedding_dim: float = 3072,
+ num_attention_heads: int = 8,
+ dropout: float = 0.1,
+ attention_dropout: float = 0.1,
+ activation_dropout: float = 0.1,
+ activation_fn: str = "relu",
+ layer_norm_first: bool = False,
+ adapter_num=201,
+ adapter_dim=64,
+ adapter_act_fn="relu",
+ ) -> None:
+ super().__init__(
+ embedding_dim=embedding_dim,
+ ffn_embedding_dim=ffn_embedding_dim,
+ num_attention_heads=num_attention_heads,
+ dropout=dropout,
+ attention_dropout=attention_dropout,
+ activation_dropout=activation_dropout,
+ activation_fn=activation_fn,
+ layer_norm_first=layer_norm_first,
+ )
+
+ self.adapter_num = adapter_num
+ self.adapter_dim = adapter_dim
+ self.adapter_layer = AdapterFast(
+ adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ self_attn_mask: torch.Tensor = None,
+ self_attn_padding_mask: torch.Tensor = None,
+ need_weights: bool = False,
+ att_args=None,
+ corpus_key=None,
+ ):
+ x, (attn, layer_result) = super().forward(
+ x=x,
+ self_attn_mask=self_attn_mask,
+ self_attn_padding_mask=self_attn_padding_mask,
+ need_weights=need_weights,
+ att_args=att_args,
+ )
+ assert corpus_key is not None
+ assert len(set(corpus_key)) == 1, f"corpus_key items are not the same: {corpus_key}"
+ y = self.adapter_layer(x, corpus_key[0])
+ x = x + y
+ return x, (attn, layer_result)
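
As a quick sanity check of the adapter above, here is a minimal, hedged usage sketch. It assumes `AdapterFast` from this file is in scope; the tensor layout and sizes are made up for illustration.

```python
import torch

# One bottleneck adapter per corpus, all packed into 3D parameter tensors,
# so selecting an adapter is plain indexing rather than a ModuleList lookup.
adapter = AdapterFast(adapter_num=3, input_dim=8, hidden_dim=4, act_fn="relu")

x = torch.randn(10, 2, 8)      # (seq_len, batch, input_dim); layout is an assumption
y = adapter(x, adapter_id=1)   # every frame goes through adapter #1
assert y.shape == x.shape      # the bottleneck projects back to input_dim
```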
diff --git a/egs/librispeech/SSL/local/attach_kmeans_to_supervisions.py b/egs/librispeech/SSL/local/attach_kmeans_to_supervisions.py
new file mode 100644
index 000000000..aa2f45f75
--- /dev/null
+++ b/egs/librispeech/SSL/local/attach_kmeans_to_supervisions.py
@@ -0,0 +1,52 @@
+import os
+
+import jsonlines
+from tqdm import tqdm
+
+os.system(
+ "cp /userhome/user/yfy62/librispeech_data/data4ssl/manifests/librispeech_*_dev-clean* ."
+)
+os.system(
+ "cp /userhome/user/yfy62/librispeech_data/data4ssl/manifests/librispeech_*_train* ."
+)
+os.system("chmod -R 644 *.jsonl.gz")
+os.system("gunzip *.gz")
+
+dataset_parts = (
+ "dev-clean",
+ "train-clean-100",
+ "train-clean-360",
+ "train-other-500",
+)
+
+kmeans_dir = "/userhome/user/yangguanrou/data/k500"
+idx_dir = "/userhome/user/yangguanrou/data/shu"
+
+kmeans = []
+idxs = []
+for part in ["train", "valid"]:
+ with open(kmeans_dir + "/" + part + ".km", "r") as f:
+ kmeans += f.read().splitlines()
+
+ with open(idx_dir + "/" + part + ".tsv", "r") as f:
+ lines = f.read().splitlines()
+ idxs += [
+ line.split("\t", -1)[0].split("/", -1)[-1].replace(".flac", "")
+ for line in lines
+ if ".flac" in line
+ ]
+
+idx2kmeans = {}
+for idx, km in zip(idxs, kmeans):
+ idx2kmeans[idx] = km
+
+for part in dataset_parts:
+ with jsonlines.open(f"librispeech_supervisions_{part}.jsonl") as reader:
+ with jsonlines.open(
+ f"librispeech_supervisions_{part}_new.jsonl", mode="w"
+ ) as writer:
+ for obj in tqdm(reader):
+ obj["custom"] = {"kmeans": idx2kmeans[obj["id"]]}
+ writer.write(obj)
+
+os.system('for file in *_new.jsonl; do mv "$file" "${file%_new.jsonl}.jsonl"; done')
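
For readers unfamiliar with the fairseq-style inputs consumed above, here is a hedged mini-example of how an utterance ID is derived from a `.tsv` line and paired with its `.km` labels. The path, sample count, and labels are made up, and the assumed `.tsv` layout (`<relative/path.flac>\t<num_samples>`) mirrors what the parsing code above expects.

```python
# Each .tsv line is assumed to be "<relative/path.flac>\t<num_samples>"; the
# corresponding .km line holds space-separated frame-level k-means labels.
tsv_line = "1089/134686/1089-134686-0000.flac\t166960"
km_line = "17 17 42 42 3"

utt_id = tsv_line.split("\t", -1)[0].split("/", -1)[-1].replace(".flac", "")
idx2kmeans = {utt_id: km_line}
assert idx2kmeans["1089-134686-0000"] == "17 17 42 42 3"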
diff --git a/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py
new file mode 100644
index 000000000..4212cd9c6
--- /dev/null
+++ b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py
@@ -0,0 +1,18 @@
+# A simple script to convert a fairseq checkpoint into a PyTorch parameter state dict.
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+import torch
+
+parser = ArgumentParser()
+parser.add_argument("--src")
+parser.add_argument("--tgt")
+
+args = parser.parse_args()
+src = args.src
+tgt = args.tgt
+
+old_checkpoint = torch.load(src)
+new_checkpoint = OrderedDict()
+new_checkpoint["model"] = old_checkpoint["model"]
+torch.save(new_checkpoint, tgt)
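
A hedged sketch of how the converted file could be consumed afterwards; the path is hypothetical, and `map_location` is passed defensively in case the fairseq checkpoint was saved from a GPU.

```python
import torch

ckpt = torch.load("hubert_base_converted.pt", map_location="cpu")  # hypothetical path
state_dict = ckpt["model"]  # the script above keeps only this key
print(f"{len(state_dict)} tensors in the converted checkpoint")
```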
diff --git a/egs/librispeech/SSL/local/prepare_char.py b/egs/librispeech/SSL/local/prepare_char.py
new file mode 100644
index 000000000..8cc0502c2
--- /dev/null
+++ b/egs/librispeech/SSL/local/prepare_char.py
@@ -0,0 +1,259 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+
+This script takes as input `lang_dir`, which should contain::
+
+ - lang_dir/text,
+ - lang_dir/words.txt
+
+and generates the following files in the directory `lang_dir`:
+
+ - lexicon.txt
+ - lexicon_disambig.txt
+ - L.pt
+ - L_disambig.pt
+ - tokens.txt
+"""
+
+import argparse
+import re
+from pathlib import Path
+from typing import Dict, List
+
+import k2
+import torch
+from prepare_lang import (
+ Lexicon,
+ add_disambig_symbols,
+ add_self_loops,
+ write_lexicon,
+ write_mapping,
+)
+
+
+def lexicon_to_fst_no_sil(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format).
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ loop_state = 0 # words enter and leave from here
+ next_state = 1 # the next un-allocated state, will be incremented as we go
+
+ arcs = []
+
+ # The blank symbol <blk> is defined in local/train_bpe_model.py
+ assert token2id["<blk>"] == 0
+ assert word2id["<eps>"] == 0
+
+ eps = 0
+
+ for word, pieces in lexicon:
+ assert len(pieces) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
+
+ for i in range(len(pieces) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, pieces[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last piece of this word
+ i = len(pieces) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, pieces[i], w, 0])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
+ """Check if all the given tokens are in token symbol table.
+
+ Args:
+ token_sym_table:
+ Token symbol table that contains all the valid tokens.
+ tokens:
+ A list of tokens.
+ Returns:
+ Return True if there is any token not in the token_sym_table,
+ otherwise False.
+ """
+ for tok in tokens:
+ if tok not in token_sym_table:
+ return True
+ return False
+
+
+def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon:
+ """Generate a lexicon from a word list and token_sym_table.
+
+ Args:
+ token_sym_table:
+ Token symbol table that maps tokens to token IDs.
+ words:
+ A list of strings representing words.
+ Returns:
+ Return a lexicon, i.e., a list of (word, tokens) pairs, where the
+ tokens of a word are its characters.
+ """
+ lexicon = []
+ for word in words:
+ chars = list(word.strip(" \t"))
+ if contain_oov(token_sym_table, chars):
+ continue
+ lexicon.append((word, chars))
+
+ # The OOV word is <UNK>
+ lexicon.append(("<UNK>", ["<unk>"]))
+ return lexicon
+
+
+def generate_tokens(text_file: str) -> Dict[str, int]:
+ """Generate tokens from the given text file.
+
+ Args:
+ text_file:
+ A file that contains text lines to generate tokens.
+ Returns:
+ Return a dict whose keys are tokens and values are token IDs ranging
+ from 0 to len(keys) - 1.
+ """
+ tokens: Dict[str, int] = dict()
+ tokens[""] = 0
+ tokens[""] = 1
+ tokens[""] = 2
+ whitespace = re.compile(r"([ \t\r\n]+)")
+ with open(text_file, "r", encoding="utf-8") as f:
+ for line in f:
+ line = re.sub(whitespace, "", line)
+ chars = list(line)
+ for char in chars:
+ if char not in tokens:
+ tokens[char] = len(tokens)
+ return tokens
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ It should contain the text and words.txt files.
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+ text_file = lang_dir / "text"
+
+ word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+
+ words = word_sym_table.symbols
+
+ excluded = ["", "!SIL", "", "", "#0", "", ""]
+ for w in excluded:
+ if w in words:
+ words.remove(w)
+
+ token_sym_table = generate_tokens(text_file)
+
+ lexicon = generate_lexicon(token_sym_table, words)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ next_token_id = max(token_sym_table.values()) + 1
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in token_sym_table
+ token_sym_table[disambig] = next_token_id
+ next_token_id += 1
+
+ word_sym_table.add("#0")
+ word_sym_table.add("")
+ word_sym_table.add("")
+
+ write_mapping(lang_dir / "tokens.txt", token_sym_table)
+
+ write_lexicon(lang_dir / "lexicon.txt", lexicon)
+ write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst_no_sil(
+ lexicon,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ )
+
+ L_disambig = lexicon_to_fst_no_sil(
+ lexicon_disambig,
+ token2id=token_sym_table,
+ word2id=word_sym_table,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), lang_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
+
+
+if __name__ == "__main__":
+ main()
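
To make `generate_lexicon` above concrete, here is a small hedged example with made-up symbols; words containing an out-of-vocabulary character are skipped and a single `<UNK>` entry is appended at the end.

```python
token_sym_table = {"<blk>": 0, "<sos/eos>": 1, "<unk>": 2, "a": 3, "b": 4}
words = ["ab", "ba", "ax"]  # "ax" contains the OOV character "x" and is dropped

lexicon = generate_lexicon(token_sym_table, words)
# Expected result:
#   [("ab", ["a", "b"]), ("ba", ["b", "a"]), ("<UNK>", ["<unk>"])]
```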
diff --git a/egs/librispeech/SSL/local/prepare_lang.py b/egs/librispeech/SSL/local/prepare_lang.py
new file mode 100644
index 000000000..c8cf9b881
--- /dev/null
+++ b/egs/librispeech/SSL/local/prepare_lang.py
@@ -0,0 +1,388 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
+consisting of words and tokens (i.e., phones) and does the following:
+
+1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
+
+2. Generate tokens.txt, the token table mapping a token to a unique integer.
+
+3. Generate words.txt, the word table mapping a word to a unique integer.
+
+4. Generate L.pt, in k2 format. It can be loaded by
+
+ d = torch.load("L.pt")
+ lexicon = k2.Fsa.from_dict(d)
+
+5. Generate L_disambig.pt, in k2 format.
+"""
+import argparse
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import k2
+import torch
+
+from icefall.lexicon import read_lexicon, write_lexicon
+
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_tokens(lexicon: Lexicon) -> List[str]:
+ """Get tokens from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique tokens.
+ """
+ ans = set()
+ for _, tokens in lexicon:
+ ans.update(tokens)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def get_words(lexicon: Lexicon) -> List[str]:
+ """Get words from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique words.
+ """
+ ans = set()
+ for word, _ in lexicon:
+ ans.add(word)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
+ """It adds pseudo-token disambiguation symbols #1, #2 and so on
+ at the ends of tokens to ensure that all pronunciations are different,
+ and that none is a prefix of another.
+
+ See also add_lex_disambig.pl from kaldi.
+
+ Args:
+ lexicon:
+ It is returned by :func:`read_lexicon`.
+ Returns:
+ Return a tuple with two elements:
+
+ - The output lexicon with disambiguation symbols
+ - The ID of the max disambiguation symbol that appears
+ in the lexicon
+ """
+
+ # (1) Work out the count of each token-sequence in the
+ # lexicon.
+ count = defaultdict(int)
+ for _, tokens in lexicon:
+ count[" ".join(tokens)] += 1
+
+ # (2) For each left sub-sequence of each token-sequence, note down
+ # that it exists (for identifying prefixes of longer strings).
+ issubseq = defaultdict(int)
+ for _, tokens in lexicon:
+ tokens = tokens.copy()
+ tokens.pop()
+ while tokens:
+ issubseq[" ".join(tokens)] = 1
+ tokens.pop()
+
+ # (3) For each entry in the lexicon:
+ # if the token sequence is unique and is not a
+ # prefix of another word, no disambig symbol.
+ # Else output #1, or #2, #3, ... if the same token-seq
+ # has already been assigned a disambig symbol.
+ ans = []
+
+ # We start with #1 since #0 has its own purpose
+ first_allowed_disambig = 1
+ max_disambig = first_allowed_disambig - 1
+ last_used_disambig_symbol_of = defaultdict(int)
+
+ for word, tokens in lexicon:
+ tokenseq = " ".join(tokens)
+ assert tokenseq != ""
+ if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
+ ans.append((word, tokens))
+ continue
+
+ cur_disambig = last_used_disambig_symbol_of[tokenseq]
+ if cur_disambig == 0:
+ cur_disambig = first_allowed_disambig
+ else:
+ cur_disambig += 1
+
+ if cur_disambig > max_disambig:
+ max_disambig = cur_disambig
+ last_used_disambig_symbol_of[tokenseq] = cur_disambig
+ tokenseq += f" #{cur_disambig}"
+ ans.append((word, tokenseq.split()))
+ return ans, max_disambig
+
+
+def generate_id_map(symbols: List[str]) -> Dict[str, int]:
+ """Generate ID maps, i.e., map a symbol to a unique ID.
+
+ Args:
+ symbols:
+ A list of unique symbols.
+ Returns:
+ A dict containing the mapping between symbols and IDs.
+ """
+ return {sym: i for i, sym in enumerate(symbols)}
+
+
+def add_self_loops(
+ arcs: List[List[Any]], disambig_token: int, disambig_word: int
+) -> List[List[Any]]:
+ """Adds self-loops to states of an FST to propagate disambiguation symbols
+ through it. They are added on each state with non-epsilon output symbols
+ on at least one arc out of the state.
+
+ See also fstaddselfloops.pl from Kaldi. One difference is that
+ Kaldi uses OpenFst style FSTs and it has multiple final states.
+ This function uses k2 style FSTs and it does not need to add self-loops
+ to the final state.
+
+ The input label of a self-loop is `disambig_token`, while the output
+ label is `disambig_word`.
+
+ Args:
+ arcs:
+ A list-of-list. The sublist contains
+ `[src_state, dest_state, label, aux_label, score]`
+ disambig_token:
+ It is the token ID of the symbol `#0`.
+ disambig_word:
+ It is the word ID of the symbol `#0`.
+
+ Return:
+ Return new `arcs` containing self-loops.
+ """
+ states_needs_self_loops = set()
+ for arc in arcs:
+ src, dst, ilabel, olabel, score = arc
+ if olabel != 0:
+ states_needs_self_loops.add(src)
+
+ ans = []
+ for s in states_needs_self_loops:
+ ans.append([s, s, disambig_token, disambig_word, 0])
+
+ return arcs + ans
+
+
+def lexicon_to_fst(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ sil_token: str = "SIL",
+ sil_prob: float = 0.5,
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format) with optional silence at
+ the beginning and end of each word.
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ sil_token:
+ The silence token.
+ sil_prob:
+ The probability for adding a silence at the beginning and end
+ of the word.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ assert sil_prob > 0.0 and sil_prob < 1.0
+ # CAUTION: we use score, i.e., negative cost.
+ sil_score = math.log(sil_prob)
+ no_sil_score = math.log(1.0 - sil_prob)
+
+ start_state = 0
+ loop_state = 1 # words enter and leave from here
+ sil_state = 2 # words terminate here when followed by silence; this state
+ # has a silence transition to loop_state.
+ next_state = 3 # the next un-allocated state, will be incremented as we go.
+ arcs = []
+
+ assert token2id[""] == 0
+ assert word2id[""] == 0
+
+ eps = 0
+
+ sil_token = token2id[sil_token]
+
+ arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+ arcs.append([start_state, sil_state, eps, eps, sil_score])
+ arcs.append([sil_state, loop_state, sil_token, eps, 0])
+
+ for word, tokens in lexicon:
+ assert len(tokens) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ tokens = [token2id[i] for i in tokens]
+
+ for i in range(len(tokens) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, tokens[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last token of this word
+ # It has two out-going arcs, one to the loop state,
+ # the other one to the sil_state.
+ i = len(tokens) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
+ arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+
+ if need_self_loops:
+ disambig_token = token2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone")
+ return parser.parse_args()
+
+
+def main():
+ out_dir = Path(get_args().lang_dir)
+ lexicon_filename = out_dir / "lexicon.txt"
+ sil_token = "SIL"
+ sil_prob = 0.5
+
+ lexicon = read_lexicon(lexicon_filename)
+ tokens = get_tokens(lexicon)
+ words = get_words(lexicon)
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in tokens
+ tokens.append(f"#{i}")
+
+ assert "" not in tokens
+ tokens = [""] + tokens
+
+ assert "" not in words
+ assert "#0" not in words
+ assert "" not in words
+ assert "" not in words
+
+ words = [""] + words + ["#0", "", ""]
+
+ token2id = generate_id_map(tokens)
+ word2id = generate_id_map(words)
+
+ write_mapping(out_dir / "tokens.txt", token2id)
+ write_mapping(out_dir / "words.txt", word2id)
+ write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+ L = lexicon_to_fst(
+ lexicon,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ )
+
+ L_disambig = lexicon_to_fst(
+ lexicon_disambig,
+ token2id=token2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), out_dir / "L.pt")
+ torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
+
+ if False:
+ # Just for debugging, will remove it
+ L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
+ L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
+ L_disambig.labels_sym = L.labels_sym
+ L_disambig.aux_labels_sym = L.aux_labels_sym
+ L.draw(out_dir / "L.png", title="L")
+ L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+
+
+if __name__ == "__main__":
+ main()
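
A small hedged example of what `add_disambig_symbols` produces for a lexicon containing a repeated pronunciation and a prefix pronunciation (the entries are made up):

```python
lexicon = [
    ("HELLO", ["HH", "EH", "L", "OW"]),
    ("HALLO", ["HH", "EH", "L", "OW"]),  # identical token sequence
    ("HE", ["HH", "EH"]),                # prefix of the sequences above
]

lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
# Expected result (max_disambig == 2):
#   HELLO  HH EH L OW #1
#   HALLO  HH EH L OW #2
#   HE     HH EH #1
```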
diff --git a/egs/librispeech/SSL/local/process_librispeech4finetune.py b/egs/librispeech/SSL/local/process_librispeech4finetune.py
new file mode 100644
index 000000000..09f4b8a3e
--- /dev/null
+++ b/egs/librispeech/SSL/local/process_librispeech4finetune.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from lhotse import CutSet
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="""Dataset parts to compute fbank. If None, we will use all""",
+ )
+
+ return parser.parse_args()
+
+
+def process_wav_librispeech(
+ dataset: Optional[str] = None,
+):
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/wav")
+
+ if dataset is None:
+ dataset_parts = (
+ "dev-clean",
+ "dev-other",
+ "test-clean",
+ "test-other",
+ "train-clean-100",
+ "train-clean-360",
+ "train-other-500",
+ )
+ else:
+ dataset_parts = dataset.split(" ", -1)
+
+ prefix = "librispeech"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ for partition, m in manifests.items():
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ args = get_args()
+ logging.info(vars(args))
+ process_wav_librispeech(
+ dataset=args.dataset,
+ )
diff --git a/egs/librispeech/SSL/local/process_librispeech4pretrain.py b/egs/librispeech/SSL/local/process_librispeech4pretrain.py
new file mode 100644
index 000000000..c375a2df3
--- /dev/null
+++ b/egs/librispeech/SSL/local/process_librispeech4pretrain.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from lhotse import CutSet
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="""Dataset parts to compute fbank. If None, we will use all""",
+ )
+
+ return parser.parse_args()
+
+
+def process_kmeans_librispeech(
+ dataset: Optional[str] = None,
+):
+ src_dir = Path(".")
+ output_dir = Path(".")
+
+ if dataset is None:
+ dataset_parts = (
+ "dev-clean",
+ "train-clean-100",
+ "train-clean-360",
+ "train-other-500",
+ )
+ else:
+ dataset_parts = dataset.split(" ", -1)
+
+ prefix = "librispeech"
+ suffix = "jsonl"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ for partition, m in manifests.items():
+ cuts_filename = f"{prefix}_cuts_{partition}_raw.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ args = get_args()
+ logging.info(vars(args))
+ process_kmeans_librispeech(
+ dataset=args.dataset,
+ )
diff --git a/egs/librispeech/SSL/local/process_raw_cuts.py b/egs/librispeech/SSL/local/process_raw_cuts.py
new file mode 100644
index 000000000..9d2ee5945
--- /dev/null
+++ b/egs/librispeech/SSL/local/process_raw_cuts.py
@@ -0,0 +1,23 @@
+import os
+
+import jsonlines
+from tqdm import tqdm
+
+dataset_parts = (
+ "dev-clean",
+ "train-clean-100",
+ "train-clean-360",
+ "train-other-500",
+)
+
+for part in dataset_parts:
+ with jsonlines.open(f"librispeech_cuts_{part}_raw.jsonl") as reader:
+ with jsonlines.open(f"librispeech_cuts_{part}.jsonl", mode="w") as writer:
+ for obj in tqdm(reader):
+ obj["custom"] = {"kmeans": obj["supervisions"][0]["custom"]["kmeans"]}
+ del obj["supervisions"][0]["custom"]
+
+ writer.write(obj)
+
+os.system("rm *_raw.jsonl")
+os.system("gzip *.jsonl")
diff --git a/egs/librispeech/SSL/shared b/egs/librispeech/SSL/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/librispeech/SSL/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/librispeech/SSL/zipformer/asr_datamodule.py b/egs/librispeech/SSL/zipformer/asr_datamodule.py
new file mode 120000
index 000000000..21a701163
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/asr_datamodule.py
@@ -0,0 +1 @@
+../hubert/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/zipformer/beam_search.py b/egs/librispeech/SSL/zipformer/beam_search.py
new file mode 120000
index 000000000..f4d4b5732
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../ASR/zipformer/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/zipformer/dataset.py b/egs/librispeech/SSL/zipformer/dataset.py
new file mode 120000
index 000000000..cb5aedde1
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/dataset.py
@@ -0,0 +1 @@
+../hubert/dataset.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/zipformer/decode.py b/egs/librispeech/SSL/zipformer/decode.py
new file mode 100644
index 000000000..1562c28b8
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/decode.py
@@ -0,0 +1,1043 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao,
+# Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from finetune import add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search".
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7".
+ - value: It contains the decoding result. `len(value)` equals the
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+
+ encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(batch["supervisions"]["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["cuts"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ dev_clean_cuts = librispeech.dev_clean_cuts()
+ dev_other_cuts = librispeech.dev_other_cuts()
+
+ dev_clean_dl = librispeech.test_dataloaders(
+ dev_clean_cuts,
+ do_normalize=params.do_normalize,
+ )
+ dev_other_dl = librispeech.test_dataloaders(
+ dev_other_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ test_clean_cuts = librispeech.test_clean_cuts()
+ test_other_cuts = librispeech.test_other_cuts()
+
+ test_clean_dl = librispeech.test_dataloaders(
+ test_clean_cuts,
+ do_normalize=params.do_normalize,
+ )
+ test_other_dl = librispeech.test_dataloaders(
+ test_other_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
+ test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/SSL/zipformer/decoder.py b/egs/librispeech/SSL/zipformer/decoder.py
new file mode 120000
index 000000000..a2138e5da
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/decoder.py
@@ -0,0 +1 @@
+../../ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/zipformer/encoder_interface.py b/egs/librispeech/SSL/zipformer/encoder_interface.py
new file mode 120000
index 000000000..0afd669f2
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../ASR/zipformer/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py
new file mode 100644
index 000000000..bbb445320
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/finetune.py
@@ -0,0 +1,1551 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For fine-tuning the zipformer-based HuBERT model:
+./zipformer/finetune.py \
+ --world-size 8 \
+ --num-epochs 200 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir hubert/exp \
+ --full-libri 0 \
+ --max-duration 1000
+
+It supports finetuning with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from hubert_ce import HubertModel
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
+ return (
+ params.batch_idx_train
+ * params.accum_grad
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
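+    # Worked example (hypothetical numbers): with accum_grad=1, max_duration=1000,
+    # world_size=8 and ref_duration=600, every optimizer step counts as
+    # 1000 * 8 / 600 ≈ 13.3 "reference" batches, so batch_idx_train=300 gives an
+    # adjusted batch count of about 4000.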
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ # hubert parameters
+ parser.add_argument(
+ "--label-rate",
+ type=float,
+ default=50,
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=float,
+ default=16000,
+ )
+
+ parser.add_argument(
+ "--extractor-mode",
+ type=str,
+ default="default",
+ help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. default has a single group
+ norm with d groups in the first conv block, whereas layer_norm
+ has layer norms in every block (meant to use with normalize=True)""",
+ )
+
+ parser.add_argument(
+ "--conv-feature-layers",
+ type=str,
+ default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+ help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
+ )
+
+ parser.add_argument(
+ "--conv-bias", type=bool, default=False, help="include bias in conv encoder"
+ )
+
+ parser.add_argument(
+ "--feature-grad-mult",
+ type=float,
+ default=1.0,
+ help="multiply feature extractor var grads by this",
+ )
+
+ # masking
+ parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
+
+ parser.add_argument(
+ "--mask-prob",
+ type=float,
+ default=0.65,
+ help="probability of replacing a token with mask",
+ )
+
+ parser.add_argument(
+ "--mask-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length",
+ )
+
+ parser.add_argument(
+ "--mask-other",
+ type=float,
+ default=0,
+ help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh",
+ )
+
+ parser.add_argument(
+ "--no-mask-overlap",
+ type=bool,
+ default=False,
+ help="whether to allow masks to overlap",
+ )
+
+ parser.add_argument(
+ "--mask-min-space",
+ type=int,
+ default=1,
+ help="min space between spans (if no overlap is enabled)",
+ )
+
+ # channel masking
+ parser.add_argument(
+ "--mask-channel-length",
+ type=int,
+ default=10,
+ help="length of the mask for features (channels)",
+ )
+
+ parser.add_argument(
+ "--mask-channel-prob",
+ type=float,
+ default=0.0,
+ help="probability of replacing a feature with 0",
+ )
+
+ parser.add_argument(
+ "--mask-channel-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length for channel masking",
+ )
+
+ parser.add_argument(
+ "--mask-channel-other",
+ type=float,
+ default=0,
+ help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh",
+ )
+
+ parser.add_argument(
+ "--no-mask-channel-overlap",
+ type=bool,
+ default=False,
+ help="whether to allow channel masks to overlap",
+ )
+
+ parser.add_argument(
+ "--mask-channel-min-space",
+ type=int,
+ default=1,
+ help="min space between spans (if no overlap is enabled)",
+ )
+
+ # loss computation
+ parser.add_argument(
+ "--skip-masked",
+ type=bool,
+ default=False,
+ help="skip computing losses over masked frames",
+ )
+
+ parser.add_argument(
+ "--skip-nomask",
+ type=bool,
+ default=False,
+ help="skip computing losses over unmasked frames",
+ )
+
+ parser.add_argument(
+ "--checkpoint-activations",
+ type=bool,
+ default=False,
+ help="recompute activations and save memory for extra compute",
+ )
+
+ parser.add_argument(
+ "--pred-masked-weight",
+ type=float,
+ default=1,
+ help="weight for masked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--pred-nomask-weight",
+ type=float,
+ default=0,
+ help="weight for masked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--loss-weights",
+ type=float,
+ nargs="*",
+ default=[10],
+ help="weight for masked part in ssl loss",
+ )
+
+ # FP16 optimization
+ parser.add_argument(
+ "--required-seq-len-multiple",
+ type=int,
+ default=2,
+ help="pad the input to encoder such that the sequence length is divisible by multiple",
+ )
+
+ parser.add_argument(
+ "--attn-type", type=str, default="", help="if espnet use ESPNET MHA"
+ )
+
+ parser.add_argument(
+ "--pos-enc-type",
+ type=str,
+ default="abs",
+ help="Positional encoding type to use in conformer",
+ )
+
+ parser.add_argument(
+ "--logit-temp", type=float, default=0.1, help="temperature to divide logits by"
+ )
+
+ parser.add_argument(
+ "--dropout-input",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the input (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--dropout-features",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the features (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--num-classes",
+ type=int,
+ nargs="*",
+ default=[504],
+ help="""num class, a little larger than the number of cluster,
+ the largest is for padding,
+ and the value should be the multiple of 4, for faster computation""",
+ )
+
+ parser.add_argument(
+ "--untie-final-proj",
+ type=bool,
+ default=False,
+ help="use separate projection for each target",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=222,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="hubert/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--pretrained-dir",
+ type=str,
+ help="""The pretrained model dir.
+ It specifies the directory where the pretrained checkpoint is saved.""",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.001, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=100000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=100,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--sanity-check",
+ type=str2bool,
+ default=False,
+ help="Check if any of the batches in epoch 1 would cause OOM.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=100000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
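+    # Worked example for the formula above (hypothetical numbers): with
+    # average_period=200 and batch_idx_train=1000, the update is
+    # model_avg = model * 0.2 + model_avg * 0.8, i.e. model_avg keeps a running
+    # average of the parameters sampled every 200 optimizer steps.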
+
+ parser.add_argument(
+ "--accum-grad",
+ type=int,
+ default=1,
+ help="""update gradient when batch_idx_train % accum_grad == 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+        contains the number of updates applied to the model so far across
+        epochs.
+
+    - sub_batch_idx_train: It contains the number of sub-batches processed so far
+        across epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "sub_batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for pruned RNN-T loss
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    if getattr(params, "pretrained_dir", None) is not None:
+ logging.info(f"Loading {params.pretrained_dir}")
+ pretrained = torch.load(params.pretrained_dir)
+ encoder = HubertModel(params)
+ encoder.load_state_dict(pretrained["model"])
+ else:
+ encoder = HubertModel(params)
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `dataset.HubertAsrDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss, num_frames = model(
+ x=audio,
+ padding_mask=padding_mask,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+            # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
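+            # Worked example of this warm-up ramp (using the defaults
+            # warm_step=2000 and simple_loss_scale=0.5): at batch_idx_train=0 the
+            # scales are (simple=1.0, pruned=0.1); at 1000 they are (0.75, 0.55);
+            # from 2000 onwards they stay at (0.5, 1.0).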
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = num_frames.sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for sub_batch_idx, batch in enumerate(train_dl):
+ params.sub_batch_idx_train += 1
+ batch_idx = sub_batch_idx // params.accum_grad
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss / params.accum_grad).backward()
+
+ if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
+ params.batch_idx_train += 1
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ else:
+ continue
+
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ if batch_idx % params.accum_grad != params.accum_grad - 1:
+ optimizer.zero_grad()
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_batches=0)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ train_cuts = (
+ librispeech.train_all_shuf_cuts()
+ if params.full_libri
+ else librispeech.train_clean_100_cuts()
+ )
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts,
+ do_normalize=params.do_normalize,
+ sampler_state_dict=sampler_state_dict,
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+
+ valid_dl = librispeech.valid_dataloaders(
+ valid_cuts,
+ do_normalize=params.do_normalize,
+ )
+
+ if params.sanity_check and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `dataset.HubertAsrDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ audio = batch["audio"]
+ logging.info(f"audio shape: {audio.shape}")
+
+ y = sp.encode(batch["supervisions"]["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/SSL/zipformer/hubert_ce.py b/egs/librispeech/SSL/zipformer/hubert_ce.py
new file mode 100644
index 000000000..1ac368a1d
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/hubert_ce.py
@@ -0,0 +1,601 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import argparse
+import logging
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scaling import ScheduledFloat
+from utils import GradMultiply, LayerNorm
+from wav2vec2_module import ConvFeatureExtractionModel
+from zipformer import Zipformer2
+
+
+def compute_mask_indices(
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+ require_same_masks: bool = True,
+ mask_dropout: float = 0.0,
+ add_masks: bool = False,
+ seed: Optional[int] = None,
+ epoch: Optional[int] = None,
+ indices: Optional[torch.Tensor] = None,
+ idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
+ num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape
+
+ Args:
+        shape: the shape for which to compute masks.
+ should be of size 2 where first element is batch size and 2nd is timesteps
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+ mask_type: how to compute mask lengths
+ static = fixed size
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+            poisson = sample from Poisson distribution with lambda = mask length
+ min_masks: minimum number of masked spans
+        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+ require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
+ mask_dropout: randomly dropout this percentage of masks in each example
+ """
+
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ if num_mask_ver == 1:
+ all_num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * all_sz / float(mask_length)
+ + np.random.rand()
+ )
+ all_num_mask = max(min_masks, all_num_mask)
+
+ mask_idcs = []
+ for i in range(bsz):
+ if seed is not None and epoch is not None and indices is not None:
+ seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
+ else:
+ seed_i = None
+
+ rng = np.random.default_rng(seed_i)
+
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ assert sz >= 0, sz
+ else:
+ sz = all_sz
+
+ if num_mask_ver == 1:
+ if padding_mask is not None:
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + np.random.rand()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ num_mask = all_num_mask
+ elif num_mask_ver == 2:
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + rng.random()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ raise ValueError()
+
+ if mask_type == "static":
+ lengths = np.full(num_mask, mask_length)
+ elif mask_type == "uniform":
+            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
+ elif mask_type == "normal":
+ lengths = rng.normal(mask_length, mask_other, size=num_mask)
+ lengths = [max(1, int(round(x))) for x in lengths]
+ elif mask_type == "poisson":
+ lengths = rng.poisson(mask_length, size=num_mask)
+ lengths = [int(round(x)) for x in lengths]
+ else:
+ raise Exception("unknown mask selection " + mask_type)
+
+ if sum(lengths) == 0:
+ if mask_type == "static":
+ raise ValueError(f"this should never happens")
+ else:
+ lengths = [min(mask_length, sz - 1)]
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+                span_start = rng.integers(s, e - length)
+ mask_idc.extend(span_start + i for i in range(length))
+
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = np.fromiter(
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    int,
+ )
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ probs = lens / np.sum(lens)
+ c = rng.choice(len(parts), p=probs)
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ if idc_select_ver == 1:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+ elif idc_select_ver == 2:
+ mask_idc = rng.choice(sz, num_mask, replace=False)
+ else:
+ raise ValueError()
+
+ mask_idc = np.asarray(
+ [
+ mask_idc[j] + offset
+ for j in range(len(mask_idc))
+ for offset in range(lengths[j])
+ ]
+ )
+
+ mask_idc = np.unique(mask_idc[mask_idc < sz])
+ if len(mask_idc) >= sz:
+ raise ValueError(
+ (
+ f"the entire sequence is masked. "
+ f"sz={sz}; mask_idc[mask_idc]; "
+ f"index={indices[i] if indices is not None else None}"
+ )
+ )
+ mask_idcs.append(mask_idc)
+
+ target_len = None
+ if require_same_masks:
+ if add_masks:
+ target_len = max([len(m) for m in mask_idcs])
+ else:
+ target_len = min([len(m) for m in mask_idcs])
+
+ for i, mask_idc in enumerate(mask_idcs):
+ if target_len is not None and len(mask_idc) > target_len:
+ mask_idc = rng.choice(mask_idc, target_len, replace=False)
+
+ mask[i, mask_idc] = True
+
+ if target_len is not None and len(mask_idc) < target_len:
+ unmasked = np.flatnonzero(~mask[i])
+ to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
+ mask[i, to_mask] = True
+
+ if mask_dropout > 0:
+ masked = np.flatnonzero(mask[i])
+ num_holes = np.rint(len(masked) * mask_dropout).astype(int)
+ to_drop = rng.choice(masked, num_holes, replace=False)
+ mask[i, to_drop] = False
+
+ return mask
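+
+# A minimal usage sketch (hypothetical values): for a batch of 2 utterances with
+# 100 frames each and no padding,
+#
+#     mask = compute_mask_indices((2, 100), None, mask_prob=0.65, mask_length=10)
+#
+# returns a boolean array of shape (2, 100) in which up to about 65% of the frames
+# are covered by spans of 10 frames (overlaps make the actual fraction smaller);
+# with require_same_masks=True (the default) both rows end up with the same number
+# of masked frames.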
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+class HubertModel(nn.Module):
+ def __init__(
+ self,
+ cfg,
+ ) -> None:
+ super().__init__()
+ feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
+ self.embed = feature_enc_layers[-1][0]
+
+ self.feature_extractor = ConvFeatureExtractionModel(
+ conv_layers=feature_enc_layers,
+ dropout=0.0,
+ mode=cfg.extractor_mode,
+ conv_bias=cfg.conv_bias,
+ )
+ feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
+ self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / cfg.sample_rate
+ encoder_input_dim = _to_int_tuple(cfg.encoder_dim)[0]
+ encoder_output_dim = max(_to_int_tuple(cfg.encoder_dim))
+ self.post_extract_proj = (
+ nn.Linear(self.embed, encoder_input_dim)
+ if self.embed != encoder_input_dim
+ else None
+ )
+
+ self.mask_prob = cfg.mask_prob
+ self.mask_selection = cfg.mask_selection
+ self.mask_other = cfg.mask_other
+ self.mask_length = cfg.mask_length
+ self.no_mask_overlap = cfg.no_mask_overlap
+ self.mask_min_space = cfg.mask_min_space
+
+ self.mask_channel_prob = cfg.mask_channel_prob
+ self.mask_channel_selection = cfg.mask_channel_selection
+ self.mask_channel_other = cfg.mask_channel_other
+ self.mask_channel_length = cfg.mask_channel_length
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+ self.mask_channel_min_space = cfg.mask_channel_min_space
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+ self.feature_grad_mult = cfg.feature_grad_mult
+ self.logit_temp = cfg.logit_temp
+ self.skip_masked = cfg.skip_masked
+ self.skip_nomask = cfg.skip_nomask
+
+ self.mask_emb = nn.Parameter(torch.FloatTensor(encoder_input_dim).uniform_())
+
+ self.encoder = Zipformer2(
+ output_downsampling_factor=1,
+ downsampling_factor=_to_int_tuple(cfg.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(cfg.num_encoder_layers),
+ encoder_dim=_to_int_tuple(cfg.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(cfg.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(cfg.query_head_dim),
+ pos_head_dim=_to_int_tuple(cfg.pos_head_dim),
+ value_head_dim=_to_int_tuple(cfg.value_head_dim),
+ pos_dim=cfg.pos_dim,
+ num_heads=_to_int_tuple(cfg.num_heads),
+ feedforward_dim=_to_int_tuple(cfg.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(cfg.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ )
+
+ self.layer_norm = LayerNorm(self.embed)
+
+ self.untie_final_proj = cfg.untie_final_proj
+ self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))
+
+ # modules below are not needed during fine-tuning
+ self.num_classes = cfg.num_classes
+ self.pred_masked_weight = cfg.pred_masked_weight
+ self.pred_nomask_weight = cfg.pred_nomask_weight
+ self.loss_weights = cfg.loss_weights
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
+
+ super().upgrade_state_dict_named(state_dict, name)
+ return state_dict
+
+ def apply_mask(self, x, padding_mask, target_list):
+ B, T, C = x.shape
+ if self.mask_prob > 0:
+ mask_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
+ x[mask_indices] = self.mask_emb.to(x.dtype)
+ else:
+ mask_indices = None
+
+ if self.mask_channel_prob > 0:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ torch.from_numpy(mask_channel_indices)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x[mask_channel_indices] = 0
+
+ return x, mask_indices
+
+ def forward_features(self, source: torch.Tensor) -> torch.Tensor:
+ if self.feature_grad_mult > 0:
+ features = self.feature_extractor(source)
+ if self.feature_grad_mult != 1.0:
+ features = GradMultiply.apply(features, self.feature_grad_mult)
+ else:
+ with torch.no_grad():
+ features = self.feature_extractor(source)
+ return features
+
+ def forward_targets(
+ self,
+ features: torch.Tensor,
+ target_list: List[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Trim features to ensure labels exist and then get aligned labels
+ feat_tsz = features.size(2)
+ targ_tsz = min([t.size(1) for t in target_list])
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
+ features = features[..., :feat_tsz]
+ target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
+ target_list = [t[:, target_inds.long()] for t in target_list]
+ return features, target_list
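+    # Worked example with the default configuration: the conv front-end strides are
+    # 5,2,2,2,2,2,2 (320x downsampling overall), so with sample_rate=16000 and
+    # label_rate=50 we get feat2tar_ratio = 50 * 320 / 16000 = 1.0, i.e. features
+    # and the frame-level targets (cluster labels) are already aligned and only
+    # trimming to the shorter length takes place.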
+
+ def forward_padding_mask(
+ self,
+ features: torch.Tensor,
+ padding_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ extra = padding_mask.size(1) % features.size(1)
+ if extra > 0:
+ padding_mask = padding_mask[:, :-extra]
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
+ padding_mask = padding_mask.all(-1)
+ return padding_mask
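+    # Sketch of what this does (assuming the default 320x downsampling of the conv
+    # front-end): a sample-level padding mask of shape (B, T_audio) is truncated so
+    # its length is a multiple of T_feat, reshaped to (B, T_feat, T_audio // T_feat)
+    # and reduced with all(-1), so a feature frame counts as padding only if every
+    # audio sample underneath it is padding.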
+
+ def forward(
+ self,
+ source: torch.Tensor,
+ target_list: Optional[List[torch.Tensor]] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = True,
+ features_only: bool = False,
+ output_layer: Optional[int] = None,
+ ):
+ """output layer is 1-based"""
+ features = self.forward_features(source)
+ if target_list is not None:
+ features, target_list = self.forward_targets(features, target_list)
+
+ features_pen = features.float().pow(2).mean()
+
+ features = features.transpose(1, 2)
+ features = self.layer_norm(features)
+ unmasked_features = features.clone()
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ features = self.dropout_input(features)
+ unmasked_features = self.dropout_features(unmasked_features)
+
+ if mask:
+ x, mask_indices = self.apply_mask(features, padding_mask, target_list)
+ else:
+ x = features
+ mask_indices = None
+
+ # feature: (B, T, D), float
+ # target: (B, T), long
+ # x: (B, T, D), float -> (T, B, D), float
+ # padding_mask: (B, T), bool
+ # mask_indices: (B, T), bool
+ x = x.transpose(0, 1)
+ x, x_lens = self.encoder(x, (~padding_mask).sum(dim=-1))
+ x = x.transpose(0, 1)
+
+ if features_only:
+ return {"x": x, "padding_mask": padding_mask, "features": features}
+
+ if not self.skip_masked:
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
+ proj_x_m = self.final_proj(x[masked_indices])
+ proj_x_m /= self.logit_temp
+ logit_m_list = [proj_x_m for _ in range(len(target_list))]
+ else:
+ logit_m_list = [None for _ in target_list]
+
+ if not self.skip_nomask:
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
+ proj_x_u = self.final_proj(x[nomask_indices])
+ proj_x_u /= self.logit_temp
+ logit_u_list = [proj_x_u for _ in range(len(target_list))]
+ else:
+ logit_u_list = [None for _ in target_list]
+
+ # result = {
+ # "logit_m_list": logit_m_list,
+ # "logit_u_list": logit_u_list,
+ # "padding_mask": padding_mask,
+ # "features_pen": features_pen,
+ # }
+ targ_m_list = target_list[0][masked_indices]
+ targ_m_list = targ_m_list.long()
+ targ_m_list = [targ_m_list for _ in range(len(target_list))]
+
+ targ_u_list = target_list[0][nomask_indices]
+ targ_u_list = targ_u_list.long()
+ targ_u_list = [targ_u_list for _ in range(len(target_list))]
+ return self.compute_loss(
+ logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
+ )
+
+ def extract_features(
+ self,
+ source: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = False,
+ ret_conv: bool = False,
+ output_layer: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ res = self.forward(
+ source,
+ padding_mask=padding_mask,
+ mask=mask,
+ features_only=True,
+ output_layer=output_layer,
+ )
+ feature = res["features"] if ret_conv else res["x"]
+ return feature, res["padding_mask"]
+
+ def get_logits(self, net_output, is_masked=True):
+ if is_masked:
+ logits_list = net_output["logit_m_list"]
+ else:
+ logits_list = net_output["logit_u_list"]
+ logits_list = [x.float() for x in logits_list if x is not None]
+ return logits_list
+
+ def get_targets(self, net_output, is_masked=True):
+ logits_list = self.get_logits(net_output, is_masked)
+ targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
+ return targets_list
+
+ def get_extra_losses(self, net_output):
+ extra_losses = []
+ names = []
+
+ if "features_pen" in net_output:
+ extra_losses.append(net_output["features_pen"])
+ names.append("features_pen")
+
+ return extra_losses, names
+
+ def remove_pretraining_modules(self):
+ self.final_proj = None
+
+ def compute_loss(
+ self, logit_m_list, logit_u_list, targ_m_list, targ_u_list, features_pen
+ ):
+ loss = 0.0
+ sample_size = 0
+ logging_output = {}
+ reduce = True
+ reduction = "sum" if reduce else "none"
+
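+ # Masked-prediction cross-entropy; sample_size accumulates the number of frames
+ # contributing to the loss so that callers can normalize by it.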
+ loss_m_list = []
+ logp_m_list = [x.float() for x in logit_m_list if x is not None]
+ logp_m_list = torch.cat(logp_m_list)
+ targ_m_list = torch.cat(targ_m_list)
+
+ loss_m = F.cross_entropy(logp_m_list, targ_m_list, reduction=reduction)
+ loss_m_list.append(loss_m)
+ logging_output[f"loss_m_0"] = loss_m.detach().item()
+
+ assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
+ if self.pred_masked_weight > 0:
+ loss += self.pred_masked_weight * sum(loss_m_list)
+ sample_size += len(targ_m_list)
+
+ loss_u_list = []
+ logp_u_list = [x.float() for x in logit_u_list if x is not None]
+ logp_u_list = torch.cat(logp_u_list)
+ targ_u_list = torch.cat(targ_u_list)
+
+ loss_u = F.cross_entropy(logp_u_list, targ_u_list, reduction=reduction)
+ loss_u_list.append(loss_u)
+ logging_output[f"loss_u_0"] = loss_u.detach().item()
+
+ assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
+ if self.pred_nomask_weight > 0:
+ loss += self.pred_nomask_weight * sum(loss_u_list)
+ sample_size += len(targ_u_list)
+
+ if self.loss_weights is not None:
+ extra_losses = []
+ names = []
+ extra_losses.append(features_pen)
+ names.append("features_pen")
+ if torch.is_tensor(extra_losses):
+ extra_losses = [extra_losses]
+ names = [names]
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
+ assert len(extra_losses) == len(
+ self.loss_weights
+ ), f"{len(extra_losses)}, {len(self.loss_weights)}"
+ for p, n, coef in zip(extra_losses, names, self.loss_weights):
+ if coef != 0 and p is not None:
+ p = coef * p.float() * sample_size
+ loss += p
+ logging_output[f"loss_{n}"] = p.item()
+
+ logging_output = {
+ "loss": loss.item() if reduce else loss,
+ **logging_output,
+ }
+
+ # for lk in self.log_keys:
+ # if lk in net_output:
+ # logging_output[lk] = float((net_output[lk]))
+
+ def compute_correct(logits, target):
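+ # Count frames whose argmax equals the target, excluding degenerate frames
+ # where argmax == argmin (i.e., constant logits).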
+ if logits.numel() == 0:
+ return 0, 0
+ else:
+ assert logits.dim() > 1, logits.shape
+ max = logits.argmax(-1) == target
+ min = logits.argmin(-1) == target
+ both = max & min
+ corr = max.long().sum().item() - both.long().sum().item()
+ count = max.numel()
+ return corr, count
+
+ with torch.no_grad():
+ corr_m, count_m = compute_correct(logp_m_list, targ_m_list)
+ logging_output[f"correct_m_0"] = corr_m
+ logging_output[f"count_m_0"] = count_m
+
+ corr_u, count_u = compute_correct(logp_u_list, targ_u_list)
+ logging_output[f"correct_u_0"] = corr_u
+ logging_output[f"count_u_0"] = count_u
+
+ return loss, sample_size, logging_output
diff --git a/egs/librispeech/SSL/zipformer/joiner.py b/egs/librispeech/SSL/zipformer/joiner.py
new file mode 120000
index 000000000..aa3362cda
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/joiner.py
@@ -0,0 +1 @@
+../../ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/zipformer/model.py b/egs/librispeech/SSL/zipformer/model.py
new file mode 100644
index 000000000..46a968b69
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/model.py
@@ -0,0 +1,344 @@
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Zengwei Yao,
+# Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from scaling import ScaledLinear
+
+from icefall.utils import add_sos
+
+
+class AsrModel(nn.Module):
+ def __init__(
+ self,
+ encoder,
+ decoder: Optional[nn.Module] = None,
+ joiner: Optional[nn.Module] = None,
+ encoder_dim: int = 768,
+ decoder_dim: int = 512,
+ vocab_size: int = 500,
+ use_transducer: bool = True,
+ use_ctc: bool = False,
+ ):
+ """A joint CTC & Transducer ASR model.
+
+ - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
+ - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
+ - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
+
+ Args:
+ encoder:
+ It is the transcription network in the paper. It accepts
+ inputs `x` of shape (N, T, encoder_dim).
+ It returns two tensors: `logits` of shape (N, T, encoder_dim) and
+ `logit_lens` of shape (N,).
+ decoder:
+ It is the prediction network in the paper. Its input shape
+ is (N, U) and its output shape is (N, U, decoder_dim).
+ It should contain one attribute: `blank_id`.
+ It is used when use_transducer is True.
+ joiner:
+ It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+ Its output shape is (N, T, U, vocab_size). Note that its output contains
+ unnormalized probs, i.e., not processed by log-softmax.
+ It is used when use_transducer is True.
+ use_transducer:
+ Whether to use the transducer head. Default: True.
+ use_ctc:
+ Whether to use the CTC head. Default: False.
+ """
+ super().__init__()
+
+ assert (
+ use_transducer or use_ctc
+ ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
+
+ self.encoder = encoder
+
+ self.use_transducer = use_transducer
+ if use_transducer:
+ # Modules for Transducer head
+ assert decoder is not None
+ assert hasattr(decoder, "blank_id")
+ assert joiner is not None
+
+ self.decoder = decoder
+ self.joiner = joiner
+
+ self.simple_am_proj = ScaledLinear(
+ encoder_dim, vocab_size, initial_scale=0.25
+ )
+ self.simple_lm_proj = ScaledLinear(
+ decoder_dim, vocab_size, initial_scale=0.25
+ )
+ else:
+ assert decoder is None
+ assert joiner is None
+
+ self.use_ctc = use_ctc
+ if use_ctc:
+ # Modules for CTC head
+ self.ctc_output = nn.Sequential(
+ nn.Dropout(p=0.1),
+ nn.Linear(encoder_dim, vocab_size),
+ nn.LogSoftmax(dim=-1),
+ )
+
+ def forward_encoder(
+ self,
+ x: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute encoder outputs.
+ Args:
+ x:
+ A 2-D tensor of shape (N, T).
+ padding_mask:
+ An optional bool tensor of shape (N, T). True marks padded samples.
+
+ Returns:
+ encoder_out:
+ Encoder output, of shape (N, T, C).
+ encoder_out_lens:
+ Encoder output lengths, of shape (N,).
+ """
+ if padding_mask is None:
+ padding_mask = torch.zeros_like(x, dtype=torch.bool)
+
+ encoder_out, padding_mask = self.encoder.extract_features(
+ source=x,
+ padding_mask=padding_mask,
+ mask=self.encoder.training,
+ )
+ encoder_out_lens = torch.sum(~padding_mask, dim=1)
+ assert torch.all(encoder_out_lens > 0), encoder_out_lens
+
+ return encoder_out, encoder_out_lens
+
+ def forward_ctc(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ targets: torch.Tensor,
+ target_lengths: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute CTC loss.
+ Args:
+ encoder_out:
+ Encoder output, of shape (N, T, C).
+ encoder_out_lens:
+ Encoder output lengths, of shape (N,).
+ targets:
+ Target Tensor of shape (sum(target_lengths)). The targets are assumed
+ to be un-padded and concatenated within 1 dimension.
+ """
+ # Compute CTC log-prob
+ ctc_output = self.ctc_output(encoder_out) # (N, T, C)
+
+ ctc_loss = torch.nn.functional.ctc_loss(
+ log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
+ targets=targets,
+ input_lengths=encoder_out_lens,
+ target_lengths=target_lengths,
+ reduction="sum",
+ )
+ return ctc_loss
+
+ def forward_transducer(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ y: k2.RaggedTensor,
+ y_lens: torch.Tensor,
+ prune_range: int = 5,
+ am_scale: float = 0.0,
+ lm_scale: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute Transducer loss.
+ Args:
+ encoder_out:
+ Encoder output, of shape (N, T, C).
+ encoder_out_lens:
+ Encoder output lengths, of shape (N,).
+ y:
+ A ragged tensor with 2 axes [utt][label]. It contains labels of each
+ utterance.
+ prune_range:
+ The prune range for the rnnt loss; it determines how many symbols
+ (context) are considered for each frame when computing the loss.
+ am_scale:
+ The scale used to smooth the loss with the am (encoder output)
+ part.
+ lm_scale:
+ The scale used to smooth the loss with the lm (prediction network
+ output) part.
+ """
+ # Now for the decoder, i.e., the prediction network
+ blank_id = self.decoder.blank_id
+ sos_y = add_sos(y, sos_id=blank_id)
+
+ # sos_y_padded: [B, S + 1], start with SOS.
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+ # decoder_out: [B, S + 1, decoder_dim]
+ decoder_out = self.decoder(sos_y_padded)
+
+ # Note: y does not start with SOS
+ # y_padded : [B, S]
+ y_padded = y.pad(mode="constant", padding_value=0)
+
+ y_padded = y_padded.to(torch.int64)
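+ # boundary has one row per utterance: [begin_symbol, begin_frame, end_symbol (S),
+ # end_frame (T)], as expected by the k2 rnnt losses.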
+ boundary = torch.zeros(
+ (encoder_out.size(0), 4),
+ dtype=torch.int64,
+ device=encoder_out.device,
+ )
+ boundary[:, 2] = y_lens
+ boundary[:, 3] = encoder_out_lens
+
+ lm = self.simple_lm_proj(decoder_out)
+ am = self.simple_am_proj(encoder_out)
+
+ # if self.training and random.random() < 0.25:
+ # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
+ # if self.training and random.random() < 0.25:
+ # am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+ lm=lm.float(),
+ am=am.float(),
+ symbols=y_padded,
+ termination_symbol=blank_id,
+ lm_only_scale=lm_scale,
+ am_only_scale=am_scale,
+ boundary=boundary,
+ reduction="sum",
+ return_grad=True,
+ )
+
+ # ranges : [B, T, prune_range]
+ ranges = k2.get_rnnt_prune_ranges(
+ px_grad=px_grad,
+ py_grad=py_grad,
+ boundary=boundary,
+ s_range=prune_range,
+ )
+
+ # am_pruned : [B, T, prune_range, encoder_dim]
+ # lm_pruned : [B, T, prune_range, decoder_dim]
+ am_pruned, lm_pruned = k2.do_rnnt_pruning(
+ am=self.joiner.encoder_proj(encoder_out),
+ lm=self.joiner.decoder_proj(decoder_out),
+ ranges=ranges,
+ )
+
+ # logits : [B, T, prune_range, vocab_size]
+
+ # project_input=False since we already applied the joiner's input projections
+ # (encoder_proj and decoder_proj) prior to do_rnnt_pruning (an optimization for speed).
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ pruned_loss = k2.rnnt_loss_pruned(
+ logits=logits.float(),
+ symbols=y_padded,
+ ranges=ranges,
+ termination_symbol=blank_id,
+ boundary=boundary,
+ reduction="sum",
+ )
+
+ return simple_loss, pruned_loss
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ y: k2.RaggedTensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ prune_range: int = 5,
+ am_scale: float = 0.0,
+ lm_scale: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x:
+ A 2-D tensor of shape (N, T).
+ y:
+ A ragged tensor with 2 axes [utt][label]. It contains labels of each
+ utterance.
+ prune_range:
+ The prune range for the rnnt loss; it determines how many symbols
+ (context) are considered for each frame when computing the loss.
+ am_scale:
+ The scale used to smooth the loss with the am (encoder output)
+ part.
+ lm_scale:
+ The scale used to smooth the loss with the lm (prediction network
+ output) part.
+ Returns:
+ Return the transducer losses, the CTC loss and the encoder output
+ lengths, in the form of (simple_loss, pruned_loss, ctc_loss, encoder_out_lens)
+
+ Note:
+ Regarding am_scale & lm_scale, they make the loss function take
+ the form:
+ lm_scale * lm_probs + am_scale * am_probs +
+ (1-lm_scale-am_scale) * combined_probs
+ """
+ assert x.ndim == 2, x.shape
+ assert y.num_axes == 2, y.num_axes
+
+ assert x.size(0) == y.dim0, (x.shape, y.dim0)
+
+ # Compute encoder outputs
+ encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)
+
+ row_splits = y.shape.row_splits(1)
+ y_lens = row_splits[1:] - row_splits[:-1]
+
+ if self.use_transducer:
+ # Compute transducer loss
+ simple_loss, pruned_loss = self.forward_transducer(
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ y=y.to(x.device),
+ y_lens=y_lens,
+ prune_range=prune_range,
+ am_scale=am_scale,
+ lm_scale=lm_scale,
+ )
+ else:
+ simple_loss = torch.empty(0)
+ pruned_loss = torch.empty(0)
+
+ if self.use_ctc:
+ # Compute CTC loss
+ targets = y.values
+ ctc_loss = self.forward_ctc(
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ targets=targets,
+ target_lengths=y_lens,
+ )
+ else:
+ ctc_loss = torch.empty(0)
+
+ return simple_loss, pruned_loss, ctc_loss, encoder_out_lens
diff --git a/egs/librispeech/SSL/zipformer/optim.py b/egs/librispeech/SSL/zipformer/optim.py
new file mode 120000
index 000000000..56b827b8a
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/optim.py
@@ -0,0 +1 @@
+../../ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py
new file mode 100644
index 000000000..5f547e0b8
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/pretrain.py
@@ -0,0 +1,1380 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For hubert model pretraining:
+./zipformer/pretrain.py \
+ --world-size 8 \
+ --num-epochs 400 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir hubert/exp \
+ --full-libri 1 \
+ --max-duration 87.5 \
+ --accum-grad 4
+"""
+
+
+import argparse
+import copy
+import logging
+import sys
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from hubert_ce import HubertModel
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from optim import Eden, ScaledAdam
+from ssl_datamodule import LibriSpeechDataModule
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
+ return (
+ params.batch_idx_train
+ * params.accum_grad
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ # hubert parameters
+ parser.add_argument(
+ "--label-rate",
+ type=float,
+ default=50,
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=float,
+ default=16000,
+ )
+
+ parser.add_argument(
+ "--extractor-mode",
+ type=str,
+ default="default",
+ help="""mode for feature extractor, should in EXTRACTOR_MODE_CHOICES. default has a single group
+ norm with d groups in the first conv block, whereas layer_norm
+ has layer norms in every block (meant to use with normalize=True)""",
+ )
+
+ parser.add_argument(
+ "--conv-feature-layers",
+ type=str,
+ default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+ help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
+ )
+
+ parser.add_argument(
+ "--conv-bias", type=bool, default=False, help="include bias in conv encoder"
+ )
+
+ parser.add_argument(
+ "--feature-grad-mult",
+ type=float,
+ default=1.0,
+ help="multiply feature extractor var grads by this",
+ )
+
+ # masking
+ parser.add_argument("--mask-length", type=int, default=10, help="mask_length")
+
+ parser.add_argument(
+ "--mask-prob",
+ type=float,
+ default=0.65,
+ help="probability of replacing a token with mask",
+ )
+
+ parser.add_argument(
+ "--mask-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length",
+ )
+
+ parser.add_argument(
+ "--mask-other",
+ type=float,
+ default=0,
+ help="secondary mask argument (used for more complex distributions),see help in compute_mask_indicesh",
+ )
+
+ parser.add_argument(
+ "--no-mask-overlap",
+ type=bool,
+ default=False,
+ help="whether to allow masks to overlap",
+ )
+
+ parser.add_argument(
+ "--mask-min-space",
+ type=int,
+ default=1,
+ help="min space between spans (if no overlap is enabled)",
+ )
+
+ # channel masking
+ parser.add_argument(
+ "--mask-channel-length",
+ type=int,
+ default=10,
+ help="length of the mask for features (channels)",
+ )
+
+ parser.add_argument(
+ "--mask-channel-prob",
+ type=float,
+ default=0.0,
+ help="probability of replacing a feature with 0",
+ )
+
+ parser.add_argument(
+ "--mask-channel-selection",
+ type=str,
+ choices=["static", "uniform", "normal", "poisson"],
+ default="static",
+ help="how to choose mask length for channel masking",
+ )
+
+ parser.add_argument(
+ "--mask-channel-other",
+ type=float,
+ default=0,
+ help="secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh",
+ )
+
+ parser.add_argument(
+ "--no-mask-channel-overlap",
+ type=bool,
+ default=False,
+ help="whether to allow channel masks to overlap",
+ )
+
+ parser.add_argument(
+ "--mask-channel-min-space",
+ type=int,
+ default=1,
+ help="min space between spans (if no overlap is enabled)",
+ )
+
+ # loss computation
+ parser.add_argument(
+ "--skip-masked",
+ type=bool,
+ default=False,
+ help="skip computing losses over masked frames",
+ )
+
+ parser.add_argument(
+ "--skip-nomask",
+ type=bool,
+ default=False,
+ help="skip computing losses over unmasked frames",
+ )
+
+ parser.add_argument(
+ "--checkpoint-activations",
+ type=bool,
+ default=False,
+ help="recompute activations and save memory for extra compute",
+ )
+
+ parser.add_argument(
+ "--pred-masked-weight",
+ type=float,
+ default=1,
+ help="weight for masked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--pred-nomask-weight",
+ type=float,
+ default=0,
+ help="weight for masked part in ssl loss",
+ )
+
+ parser.add_argument(
+ "--loss-weights",
+ type=float,
+ nargs="*",
+ default=[10],
+ help="weight for masked part in ssl loss",
+ )
+
+ # FP16 optimization
+ parser.add_argument(
+ "--required-seq-len-multiple",
+ type=int,
+ default=2,
+ help="pad the input to encoder such that the sequence length is divisible by multiple",
+ )
+
+ parser.add_argument(
+ "--attn-type", type=str, default="", help="if espnet use ESPNET MHA"
+ )
+
+ parser.add_argument(
+ "--pos-enc-type",
+ type=str,
+ default="abs",
+ help="Positional encoding type to use in conformer",
+ )
+
+ parser.add_argument(
+ "--logit-temp", type=float, default=0.1, help="temperature to divide logits by"
+ )
+
+ parser.add_argument(
+ "--dropout-input",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the input (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--dropout-features",
+ type=float,
+ default=0.0,
+ help="dropout to apply to the features (after feat extr)",
+ )
+
+ parser.add_argument(
+ "--num-classes",
+ type=int,
+ nargs="*",
+ default=[504],
+ help="""num class, a little larger than the number of cluster,
+ the largest is for padding,
+ and the value should be the multiple of 4, for faster computation""",
+ )
+
+ parser.add_argument(
+ "--untie-final-proj",
+ type=bool,
+ default=False,
+ help="use separate projection for each target",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=400,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=10.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--warmup-batches",
+ type=float,
+ default=5000,
+ help="Eden warmup steps",
+ )
+
+ parser.add_argument(
+ "--warmup-start",
+ type=float,
+ default=0,
+ help="Eden warmup start learning rate",
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--sanity-check",
+ type=str2bool,
+ default=False,
+ help="Check if any of the batches in epoch 1 would cause OOM.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=100000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--accum-grad",
+ type=int,
+ default=4,
+ help="""update gradient when batch_idx_train % accum_grad == 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--max-keep-size",
+ type=int,
+ default=sys.maxsize,
+ help="exclude sample longer than this.",
+ )
+
+ parser.add_argument(
+ "--min-keep-size",
+ type=float,
+ default=32000,
+ help="exclude sample longer less than this.",
+ )
+
+ parser.add_argument(
+ "--max-sample-size",
+ type=float,
+ default=250000,
+ help="max sample size to crop to for batching.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of parameter updates applied to the model so far
+ across epochs.
+
+ - sub_batch_idx_train: It contains the number of sub-batches trained so far
+ across epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "sub_batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ model = HubertModel(params)
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of HubertModel (with a Zipformer encoder) in our case.
+ batch:
+ A batch of data. See `dataset.HubertDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ audio = batch["audio"].to(device)
+ padding_mask = batch["padding_mask"].to(device)
+ kmeans = batch["kmeans"].to(device)
+
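+ # The model returns (loss, num_masked_tokens, logging_output); "frames" below
+ # stores num_masked_tokens so that losses can be normalized per masked frame.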
+ with torch.set_grad_enabled(is_training):
+ loss, num_masked_tokens, logging_output = model(
+ source=audio, target_list=[kmeans], padding_mask=padding_mask
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = num_masked_tokens
+ for item in logging_output:
+ info[item] = logging_output[item]
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for sub_batch_idx, batch in enumerate(train_dl):
+ params.sub_batch_idx_train += 1
+ batch_idx = sub_batch_idx // params.accum_grad
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ batch_size = batch["kmeans"].shape[0]
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss / params.accum_grad).backward()
+
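+ # Step the optimizer only every accum_grad sub-batches; the other sub-batches
+ # just accumulate gradients via the backward() call above.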
+ if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
+ params.batch_idx_train += 1
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ else:
+ continue
+
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ if batch_idx % params.accum_grad != params.accum_grad - 1:
+ optimizer.zero_grad()
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(
+ optimizer,
+ params.lr_batches,
+ params.lr_epochs,
+ params.warmup_batches,
+ params.warmup_start,
+ )
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechDataModule(args)
+
+ train_cuts = (
+ librispeech.train_all_shuf_cuts()
+ if params.full_libri
+ else librispeech.train_clean_100_cuts()
+ )
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances whose duration (in samples) lies within
+ # [min_keep_size, max_keep_size].
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the thresholds
+ if (
+ c.duration < params.min_keep_size / params.sample_rate
+ or c.duration > params.max_keep_size / params.sample_rate
+ ):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts,
+ max_sample_size=params.max_sample_size,
+ sample_rate=params.sample_rate,
+ label_rate=params.label_rate,
+ random_crop=params.random_crop,
+ pad_audio=False,
+ num_classes=params.num_classes,
+ do_normalize=params.do_normalize,
+ sampler_state_dict=sampler_state_dict,
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ # valid_cuts += librispeech.dev_other_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
+
+ valid_dl = librispeech.valid_dataloaders(
+ valid_cuts,
+ max_sample_size=params.max_sample_size,
+ sample_rate=params.sample_rate,
+ label_rate=params.label_rate,
+ random_crop=params.random_crop,
+ pad_audio=False,
+ num_classes=params.num_classes,
+ do_normalize=params.do_normalize,
+ )
+
+ if params.sanity_check and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `dataset.HubertDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ audio = batch["audio"]
+ logging.info(f"audio shape: {audio.shape}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/SSL/zipformer/scaling.py b/egs/librispeech/SSL/zipformer/scaling.py
new file mode 120000
index 000000000..e30bd99de
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/scaling.py
@@ -0,0 +1 @@
+../../ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/zipformer/ssl_datamodule.py b/egs/librispeech/SSL/zipformer/ssl_datamodule.py
new file mode 120000
index 000000000..9f5085e3a
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/ssl_datamodule.py
@@ -0,0 +1 @@
+../hubert/ssl_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/SSL/zipformer/utils.py b/egs/librispeech/SSL/zipformer/utils.py
new file mode 100644
index 000000000..748d3c96e
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/utils.py
@@ -0,0 +1,337 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import math
+from typing import Callable, List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def relu_squared(x: torch.Tensor):
+ return F.relu(x).pow(2)
+
+
+def gelu_accurate(x):
+ if not hasattr(gelu_accurate, "_a"):
+ gelu_accurate._a = math.sqrt(2 / math.pi)
+ return (
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+ )
+
+
+def is_xla_tensor(tensor):
+ return torch.is_tensor(tensor) and tensor.device.type == "xla"
+
+
+def index_put(tensor, indices, value):
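+ # XLA devices do not support in-place boolean indexing, so emulate it with a
+ # masked blend; elsewhere fall back to plain advanced indexing.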
+ if is_xla_tensor(tensor):
+ for _ in range(indices.dim(), tensor.dim()):
+ indices = indices.unsqueeze(-1)
+ if indices.size(-1) < tensor.size(-1):
+ indices = indices.expand_as(tensor)
+ tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
+ else:
+ tensor[indices] = value
+ return tensor
+
+
+def pad_to_multiple(x, multiple, dim=-1, value=0):
+ # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
+ if x is None:
+ return None, 0
+ tsz = x.size(dim)
+ m = tsz / multiple
+ remainder = math.ceil(m) * multiple - tsz
+ if m.is_integer():
+ return x, 0
+ pad_offset = (0,) * (-1 - dim) * 2
+
+ return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.gelu(x.float()).type_as(x)
+
+
+def get_activation_fn(activation: str) -> Callable:
+ """Returns the activation function corresponding to `activation`"""
+ if activation == "relu":
+ return F.relu
+ elif activation == "relu_squared":
+ return relu_squared
+ elif activation == "gelu":
+ return gelu
+ elif activation == "gelu_fast":
+ return gelu_accurate
+ elif activation == "gelu_accurate":
+ return gelu_accurate
+ elif activation == "tanh":
+ return torch.tanh
+ elif activation == "linear":
+ return lambda x: x
+ elif activation == "swish":
+ return F.silu  # functional form of SiLU (a.k.a. swish)
+ else:
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+class SamePad(nn.Module):
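+ # Trims the extra frames introduced by "same" convolution padding: the last
+ # kernel_size - 1 frames for causal convolutions, or one frame for even kernel sizes.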
+ def __init__(self, kernel_size, causal=False):
+ super().__init__()
+ if causal:
+ self.remove = kernel_size - 1
+ else:
+ self.remove = 1 if kernel_size % 2 == 0 else 0
+
+ def forward(self, x):
+ if self.remove > 0:
+ x = x[:, :, : -self.remove]
+ return x
+
+
+class SamePad2d(nn.Module):
+ def __init__(self, kernel_size):
+ super().__init__()
+ self.remove = 1 if kernel_size % 2 == 0 else 0
+
+ def forward(self, x):
+ assert len(x.size()) == 4
+ if self.remove > 0:
+ x = x[:, :, : -self.remove, : -self.remove]
+ return x
+
+
+class TransposeLast(nn.Module):
+ def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
+ super().__init__()
+ self.deconstruct_idx = deconstruct_idx
+ self.tranpose_dim = tranpose_dim
+
+ def forward(self, x):
+ if self.deconstruct_idx is not None:
+ x = x[self.deconstruct_idx]
+ return x.transpose(self.tranpose_dim, -1)
+
+
+try:
+ from apex.normalization import FusedLayerNorm as _FusedLayerNorm
+
+ has_fused_layernorm = True
+
+ class FusedLayerNorm(_FusedLayerNorm):
+ @torch.jit.unused
+ def forward(self, x):
+ if not x.is_cuda:
+ return super().forward(x)
+ else:
+ with torch.cuda.device(x.device):
+ return super().forward(x)
+
+except ImportError:
+ has_fused_layernorm = False
+
+
+def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
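+ # Prefer apex FusedLayerNorm on CUDA when available; fall back to torch.nn.LayerNorm
+ # when exporting, scripting or tracing.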
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ export = True
+ if not export and torch.cuda.is_available() and has_fused_layernorm:
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+class Fp32LayerNorm(nn.LayerNorm):
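+ # Computes layer norm in float32 for numerical stability, then casts back to the input dtype.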
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.layer_norm(
+ input.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
+
+
+class Fp32GroupNorm(nn.GroupNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.group_norm(
+ input.float(),
+ self.num_groups,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
+
+
+def softmax(x, dim: int, onnx_trace: bool = False):
+ if onnx_trace:
+ return F.softmax(x.float(), dim=dim)
+ else:
+ return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+def quant_noise(module, p, block_size):
+ """
+ Wraps modules and applies quantization noise to the weights for
+ subsequent quantization with Iterative Product Quantization as
+ described in "Training with Quantization Noise for Extreme Model Compression"
+
+ Args:
+ - module: nn.Module
+ - p: amount of Quantization Noise
+ - block_size: size of the blocks for subsequent quantization with iPQ
+
+ Remarks:
+ - Module weights must have the right sizes wrt the block size
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
+ - For more detail on how to quantize by blocks with convolutional weights,
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
+ - We implement the simplest form of noise here as stated in the paper
+ which consists in randomly dropping blocks
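+
+ Example (an illustrative sketch; the layer sizes below are arbitrary)::
+
+ >>> import torch.nn as nn
+ >>> layer = quant_noise(nn.Linear(64, 64), p=0.1, block_size=8)
+ >>> # during training, each forward pass drops about 10% of the 8-column
+ >>> # blocks of the weight matrix and rescales the remaining ones.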
+ """
+
+ # if no quantization noise, don't register hook
+ if p <= 0:
+ return module
+
+ # supported modules
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
+
+ # test whether module.weight has the right sizes wrt block_size
+ is_conv = module.weight.ndim == 4
+
+ # 2D matrix
+ if not is_conv:
+ assert (
+ module.weight.size(1) % block_size == 0
+ ), "Input features must be a multiple of block sizes"
+
+ # 4D matrix
+ else:
+ # 1x1 convolutions
+ if module.kernel_size == (1, 1):
+ assert (
+ module.in_channels % block_size == 0
+ ), "Input channels must be a multiple of block sizes"
+ # regular convolutions
+ else:
+ k = module.kernel_size[0] * module.kernel_size[1]
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
+
+ def _forward_pre_hook(mod, input):
+ # no noise for evaluation
+ if mod.training:
+ if not is_conv:
+ # gather weight and sizes
+ weight = mod.weight
+ in_features = weight.size(1)
+ out_features = weight.size(0)
+
+ # split weight matrix into blocks and randomly drop selected blocks
+ mask = torch.zeros(
+ in_features // block_size * out_features, device=weight.device
+ )
+ mask.bernoulli_(p)
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+
+ else:
+ # gather weight and sizes
+ weight = mod.weight
+ in_channels = mod.in_channels
+ out_channels = mod.out_channels
+
+ # split weight matrix into blocks and randomly drop selected blocks
+ if mod.kernel_size == (1, 1):
+ mask = torch.zeros(
+ int(in_channels // block_size * out_channels),
+ device=weight.device,
+ )
+ mask.bernoulli_(p)
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+ else:
+ mask = torch.zeros(
+ weight.size(0), weight.size(1), device=weight.device
+ )
+ mask.bernoulli_(p)
+ mask = (
+ mask.unsqueeze(2)
+ .unsqueeze(3)
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+ )
+
+ # scale weights and apply mask
+ mask = mask.to(
+ torch.bool
+ ) # x.bool() is not currently supported in TorchScript
+ s = 1 / (1 - p)
+ mod.weight.data = s * weight.masked_fill(mask, 0)
+
+ module.register_forward_pre_hook(_forward_pre_hook)
+ return module
+
+
+class FairseqDropout(nn.Module):
+ def __init__(self, p, module_name=None):
+ super().__init__()
+ self.p = p
+ self.module_name = module_name
+ self.apply_during_inference = False
+
+ def forward(self, x, inplace: bool = False):
+ if self.p > 0 and (self.training or self.apply_during_inference):
+ return F.dropout(x, p=self.p, training=True, inplace=inplace)
+ else:
+ return x
+
+ def make_generation_fast_(
+ self,
+ name: str,
+ retain_dropout: bool = False,
+ retain_dropout_modules: Optional[List[str]] = None,
+ **kwargs
+ ):
+ if retain_dropout:
+ if retain_dropout_modules is not None and self.module_name is None:
+ pass
+ elif (
+ retain_dropout_modules is None # if None, apply to all modules
+ or self.module_name in retain_dropout_modules
+ ):
+ self.apply_during_inference = True
+
+
+class GradMultiply(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, scale):
+ ctx.scale = scale
+ res = x.new(x)
+ return res
+
+ @staticmethod
+ def backward(ctx, grad):
+ return grad * ctx.scale, None
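+
+
+if __name__ == "__main__":
+ # Illustrative sketch (not from the original recipe): GradMultiply returns
+ # x unchanged in the forward pass but scales the gradient flowing back into
+ # `x` by `scale`.
+ x = torch.randn(3, requires_grad=True)
+ y = GradMultiply.apply(x, 0.5)
+ y.sum().backward()
+ print(x.grad) # each element is 0.5 rather than 1.0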
diff --git a/egs/librispeech/SSL/zipformer/wav2vec2_module.py b/egs/librispeech/SSL/zipformer/wav2vec2_module.py
new file mode 100644
index 000000000..ab5ca005f
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/wav2vec2_module.py
@@ -0,0 +1,108 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import math
+from typing import List, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast
+
+
+class ConvFeatureExtractionModel(nn.Module):
+ def __init__(
+ self,
+ conv_layers: List[Tuple[int, int, int]],
+ dropout: float = 0.0,
+ mode: str = "default",
+ conv_bias: bool = False,
+ ):
+ super().__init__()
+
+ assert mode in {"default", "layer_norm"}
+
+ def block(
+ n_in,
+ n_out,
+ k,
+ stride,
+ is_layer_norm=False,
+ is_group_norm=False,
+ conv_bias=False,
+ ):
+ def make_conv():
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
+ nn.init.kaiming_normal_(conv.weight)
+ return conv
+
+ assert not (
+ is_layer_norm and is_group_norm
+ ), "layer norm and group norm are exclusive"
+
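+ # Note: `dim` used in the branches below is the enclosing loop variable
+ # (equal to n_out for the layer being built), not an argument of block().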
+ if is_layer_norm:
+ return nn.Sequential(
+ make_conv(),
+ nn.Dropout(p=dropout),
+ nn.Sequential(
+ TransposeLast(),
+ Fp32LayerNorm(dim, elementwise_affine=True),
+ TransposeLast(),
+ ),
+ nn.GELU(),
+ )
+ elif is_group_norm:
+ return nn.Sequential(
+ make_conv(),
+ nn.Dropout(p=dropout),
+ Fp32GroupNorm(dim, dim, affine=True),
+ nn.GELU(),
+ )
+ else:
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
+
+ in_d = 1
+ self.conv_layers = nn.ModuleList()
+ for i, cl in enumerate(conv_layers):
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
+ (dim, k, stride) = cl
+
+ self.conv_layers.append(
+ block(
+ in_d,
+ dim,
+ k,
+ stride,
+ is_layer_norm=mode == "layer_norm",
+ is_group_norm=mode == "default" and i == 0,
+ conv_bias=conv_bias,
+ )
+ )
+ in_d = dim
+
+ def forward(self, x):
+ # BxT -> BxCxT
+ x = x.unsqueeze(1)
+
+ for conv in self.conv_layers:
+ x = conv(x)
+
+ return x
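+
+
+if __name__ == "__main__":
+ # Illustrative sketch (not from the original recipe). The conv spec below is
+ # the one commonly used for the wav2vec 2.0 "base" feature extractor; it maps
+ # 16 kHz waveforms to roughly 50 frames per second (total stride of 320 samples).
+ conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
+ extractor = ConvFeatureExtractionModel(conv_layers, mode="default")
+ wav = torch.randn(4, 16000) # (batch, time) raw 16 kHz waveform
+ print(extractor(wav).shape) # -> torch.Size([4, 512, 49])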
diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py
new file mode 100644
index 000000000..e9eff3357
--- /dev/null
+++ b/egs/librispeech/SSL/zipformer/zipformer.py
@@ -0,0 +1,2438 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import math
+import random
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+from encoder_interface import EncoderInterface
+from scaling import (
+ Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+ ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+ ActivationDropoutAndLinear,
+ Balancer,
+ BiasNorm,
+ ChunkCausalDepthwiseConv1d,
+ Dropout2,
+ FloatLike,
+ ScheduledFloat,
+ Whiten,
+ convert_num_channels,
+ limit_param_value,
+ penalize_abs_values_gt,
+ softmax,
+)
+from torch import Tensor, nn
+
+
+class Zipformer2(EncoderInterface):
+ """
+ Args:
+
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
+ as downsampling_factor if they are single ints or one-element tuples. The length of
+ downsampling_factor defines the number of stacks.
+
+ output_downsampling_factor (int): how much to downsample at the output. Note:
+ we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
+ You should probably leave this at 2.
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
+ Note: this is in addition to the downsampling factor of 2 that is applied in
+ the frontend (self.encoder_embed).
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
+ encoder stack.
+ num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
+ encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
+ the encoder stacks for purposes of per-frame dropout (recommend 256 for
+ now).
+ query_head_dim (int or Tuple[int]): dimension of the query and key per attention
+ head (per stack, if a tuple).
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
+ attention head
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
+ Must be at least 4.
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
+ cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module
+
+ pos_dim (int): the dimension of each positional-encoding vector prior to projection,
+ e.g. 128.
+
+ dropout (float): dropout rate
+ warmup_batches (float): number of batches to warm up over; this controls
+ dropout of encoder layers.
+ causal (bool): if True, support chunkwise causal convolution. This should
+ not hurt WER as no modeling power is lost, but the convolution modules will be
+ slightly slower and use more memory. Enables use of the chunk_size and
+ left_context_chunks options in forward(), which simulates streaming
+ decoding.
+ chunk_size: (list of int): only set this to other than [-1] if causal;
+ the chunk size will be randomly chosen from this list. -1 means no chunking.
+ left_context_frames: (list of int): determines the number of left-
+ context chunks for causal training; will be rounded to a number of
+ chunks. Must not be less than cnn_module_kernel (after factoring in
+ rounding and downsampling); an error will be thrown if this is violated.
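+
+ Example (a hedged construction sketch; the dimensions below are illustrative
+ and are not the recipe defaults)::
+
+ >>> encoder = Zipformer2(
+ ... downsampling_factor=(1, 2, 4, 2),
+ ... encoder_dim=(192, 256, 384, 256),
+ ... encoder_unmasked_dim=128,
+ ... num_encoder_layers=2,
+ ... num_heads=4,
+ ... )
+ >>> x = torch.rand(100, 8, 192) # (seq_len, batch, encoder_dim[0])
+ >>> x_lens = torch.full((8,), 100)
+ >>> out, out_lens = encoder(x, x_lens) # out: (100, 8, 384)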
+ """
+
+ def __init__(
+ self,
+ output_downsampling_factor: int = 2,
+ downsampling_factor: Tuple[int] = (2, 4),
+ encoder_dim: Union[int, Tuple[int]] = 384,
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
+ encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
+ query_head_dim: Union[int, Tuple[int]] = 24,
+ pos_head_dim: Union[int, Tuple[int]] = 4,
+ value_head_dim: Union[int, Tuple[int]] = 12,
+ num_heads: Union[int, Tuple[int]] = 8,
+ feedforward_dim: Union[int, Tuple[int]] = 1536,
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
+ pos_dim: int = 192,
+ dropout: FloatLike = None, # see code below for default
+ warmup_batches: float = 4000.0,
+ causal: bool = False,
+ chunk_size: Tuple[int] = [-1],
+ left_context_frames: Tuple[int] = [-1],
+ ) -> None:
+ super(Zipformer2, self).__init__()
+
+ if dropout is None:
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+
+ def _to_tuple(x):
+ """Converts a single int or a 1-tuple of an int to a tuple with the same length
+ as downsampling_factor"""
+ if isinstance(x, int):
+ x = (x,)
+ if len(x) == 1:
+ x = x * len(downsampling_factor)
+ else:
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
+ return x
+
+ self.output_downsampling_factor = output_downsampling_factor # int
+ self.downsampling_factor = downsampling_factor # tuple
+ self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
+ self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
+ encoder_unmasked_dim
+ ) # tuple
+ num_encoder_layers = _to_tuple(num_encoder_layers)
+ self.num_encoder_layers = num_encoder_layers
+ self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
+ self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
+ pos_head_dim = _to_tuple(pos_head_dim)
+ self.num_heads = num_heads = _to_tuple(num_heads)
+ feedforward_dim = _to_tuple(feedforward_dim)
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
+
+ self.causal = causal
+ self.chunk_size = chunk_size
+ self.left_context_frames = left_context_frames
+
+ for u, d in zip(encoder_unmasked_dim, encoder_dim):
+ assert u <= d
+
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
+ encoders = []
+
+ num_encoders = len(downsampling_factor)
+ for i in range(num_encoders):
+ encoder_layer = Zipformer2EncoderLayer(
+ embed_dim=encoder_dim[i],
+ pos_dim=pos_dim,
+ num_heads=num_heads[i],
+ query_head_dim=query_head_dim[i],
+ pos_head_dim=pos_head_dim[i],
+ value_head_dim=value_head_dim[i],
+ feedforward_dim=feedforward_dim[i],
+ dropout=dropout,
+ cnn_module_kernel=cnn_module_kernel[i],
+ causal=causal,
+ )
+
+ # For the segment of the warmup period, we let the Conv2dSubsampling
+ # layer learn something. Then we start to warm up the other encoders.
+ encoder = Zipformer2Encoder(
+ encoder_layer,
+ num_encoder_layers[i],
+ pos_dim=pos_dim,
+ dropout=dropout,
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+ )
+
+ if downsampling_factor[i] != 1:
+ encoder = DownsampledZipformer2Encoder(
+ encoder,
+ dim=encoder_dim[i],
+ downsample=downsampling_factor[i],
+ dropout=dropout,
+ )
+
+ encoders.append(encoder)
+
+ self.encoders = nn.ModuleList(encoders)
+
+ self.downsample_output = SimpleDownsample(
+ max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
+ )
+
+ def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
+ """
+ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
+ randomized feature masks, one per encoder.
+ On e.g. 15% of frames, these masks will zero out all encoder dims larger than
+ some supplied number, e.g. >256, so in effect on those frames we are using
+ a smaller encoder dim.
+
+ We generate the random masks at this level because we want the 2 masks to 'agree'
+ all the way up the encoder stack. This will mean that the 1st mask will have
+ mask values repeated self.zipformer_subsampling_factor times.
+
+ Args:
+ x: the embeddings (needed for the shape and dtype and device), of shape
+ (1, batch_size, encoder_dims0)
+ """
+ num_encoders = len(self.encoder_dim)
+ if not self.training:
+ return [1.0] * num_encoders
+
+ (num_frames0, batch_size, _encoder_dims0) = x.shape
+
+ assert self.encoder_dim[0] == _encoder_dims0, (
+ self.encoder_dim[0],
+ _encoder_dims0,
+ )
+
+ feature_mask_dropout_prob = 0.125
+
+ # mask1 shape: (1, batch_size, 1)
+ mask1 = (
+ torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
+ ).to(x.dtype)
+
+ # mask2 has additional sequences masked, about twice the number.
+ mask2 = torch.logical_and(
+ mask1,
+ (
+ torch.rand(1, batch_size, 1, device=x.device)
+ > feature_mask_dropout_prob
+ ).to(x.dtype),
+ )
+
+ # dim: (1, batch_size, 2)
+ mask = torch.cat((mask1, mask2), dim=-1)
+
+ feature_masks = []
+ for i in range(num_encoders):
+ channels = self.encoder_dim[i]
+ feature_mask = torch.ones(
+ 1, batch_size, channels, dtype=x.dtype, device=x.device
+ )
+ u1 = self.encoder_unmasked_dim[i]
+ u2 = u1 + (channels - u1) // 2
+
+ feature_mask[:, :, u1:u2] *= mask[..., 0:1]
+ feature_mask[:, :, u2:] *= mask[..., 1:2]
+
+ feature_masks.append(feature_mask)
+
+ return feature_masks
+
+ def get_chunk_info(self) -> Tuple[int, int]:
+ """
+ Returns chunk_size and left_context_chunks.
+ """
+ if not self.causal:
+ return -1, -1
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ assert len(self.chunk_size) == 1, self.chunk_size
+ chunk_size = self.chunk_size[0]
+ else:
+ chunk_size = random.choice(self.chunk_size)
+
+ if chunk_size == -1:
+ left_context_chunks = -1
+ else:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ assert len(self.left_context_frames) == 1, self.left_context_frames
+ left_context_frames = self.left_context_frames[0]
+ else:
+ left_context_frames = random.choice(self.left_context_frames)
+ # Note: in Python, -1 // n == -1 for n > 0
+ left_context_chunks = left_context_frames // chunk_size
+ if left_context_chunks == 0:
+ left_context_chunks = 1
+
+ return chunk_size, left_context_chunks
+
+ def forward(
+ self,
+ x: Tensor,
+ x_lens: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (seq_len, batch_size, feature_dim).
+ x_lens:
+ A tensor of shape (batch_size,) containing the number of frames in
+ `x` before padding.
+ src_key_padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ Returns:
+ Return a tuple containing 2 tensors:
+ - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
+ - lengths, a tensor of shape (batch_size,) containing the number
+ of frames in `embeddings` before padding.
+ """
+ outputs = []
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ feature_masks = [1.0] * len(self.encoder_dim)
+ else:
+ feature_masks = self.get_feature_masks(x)
+
+ chunk_size, left_context_chunks = self.get_chunk_info()
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ # Simulated streaming decoding is not supported in exported models,
+ # so no attention mask is used here.
+ attn_mask = None
+ else:
+ attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
+
+ for i, module in enumerate(self.encoders):
+ ds = self.downsampling_factor[i]
+ x = convert_num_channels(x, self.encoder_dim[i])
+
+ x = module(
+ x,
+ chunk_size=chunk_size,
+ feature_mask=feature_masks[i],
+ src_key_padding_mask=(
+ None
+ if src_key_padding_mask is None
+ else src_key_padding_mask[..., ::ds]
+ ),
+ attn_mask=attn_mask,
+ )
+ outputs.append(x)
+
+ # if the last output has the largest dimension, x will be unchanged,
+ # it will be the same as outputs[-1]. Otherwise it will be concatenated
+ # from different pieces of 'outputs', taking each dimension from the
+ # most recent output that has it present.
+ x = self._get_full_dim_output(outputs)
+ # x = self.downsample_output(x)
+ # class Downsample has this rounding behavior..
+ # assert self.output_downsampling_factor == 2, self.output_downsampling_factor
+ # if torch.jit.is_scripting() or torch.jit.is_tracing():
+ # lengths = (x_lens + 1) // 2
+ # else:
+ # with warnings.catch_warnings():
+ # warnings.simplefilter("ignore")
+ # lengths = (x_lens + 1) // 2
+
+ # return x, lengths
+ return x, x_lens
+
+ def _get_attn_mask(
+ self, x: Tensor, chunk_size: int, left_context_chunks: int
+ ) -> Optional[Tensor]:
+ """
+ Return None if chunk_size == -1, else return attention mask of shape
+ (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True
+ means a masked position.
+ Args:
+ x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
+ chunk_size: chunk size; must be divisible by all elements of self.downsampling_factor.
+ """
+ if chunk_size <= 0:
+ return None
+ assert all(chunk_size % d == 0 for d in self.downsampling_factor)
+ if left_context_chunks >= 0:
+ num_encoders = len(self.encoder_dim)
+ assert all(
+ chunk_size * left_context_chunks
+ >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
+ for i in range(num_encoders)
+ )
+ else:
+ left_context_chunks = 1000000
+
+ seq_len = x.shape[0]
+
+ # t is frame index, shape (seq_len,)
+ t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
+ # c is chunk index for each frame, shape (seq_len,)
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ c = t // chunk_size
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ c = t // chunk_size
+ src_c = c
+ tgt_c = c.unsqueeze(-1)
+
+ attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
+ if __name__ == "__main__":
+ logging.info(f"attn_mask = {attn_mask}")
+ return attn_mask
+
+ def _get_full_dim_output(self, outputs: List[Tensor]):
+ num_encoders = len(self.encoder_dim)
+ assert len(outputs) == num_encoders
+ output_dim = max(self.encoder_dim)
+ output_pieces = [outputs[-1]]
+ cur_dim = self.encoder_dim[-1]
+ for i in range(num_encoders - 2, -1, -1):
+ d = self.encoder_dim[i]
+ if d > cur_dim:
+ this_output = outputs[i]
+ output_pieces.append(this_output[..., cur_dim:d])
+ cur_dim = d
+ assert cur_dim == output_dim
+ return torch.cat(output_pieces, dim=-1)
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ x_lens: Tensor,
+ states: List[Tensor],
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (seq_len, batch_size, feature_dim).
+ x_lens:
+ A tensor of shape (batch_size,) containing the number of frames in
+ `x` before padding.
+ states: list of cached tensors of all encoder layers. For layer-i,
+ states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+ cached_conv1, cached_conv2).
+ src_key_padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ Returns:
+ Return a tuple containing 3 items:
+ - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
+ - lengths, a tensor of shape (batch_size,) containing the number
+ of frames in `embeddings` before padding.
+ - updated states
+ """
+ outputs = []
+ new_states = []
+ layer_offset = 0
+
+ for i, module in enumerate(self.encoders):
+ num_layers = module.num_layers
+ ds = self.downsampling_factor[i]
+ x = convert_num_channels(x, self.encoder_dim[i])
+
+ x, new_layer_states = module.streaming_forward(
+ x,
+ states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
+ left_context_len=self.left_context_frames[0] // ds,
+ src_key_padding_mask=src_key_padding_mask[..., ::ds],
+ )
+ layer_offset += num_layers
+ outputs.append(x)
+ new_states += new_layer_states
+
+ # if the last output has the largest dimension, x will be unchanged,
+ # it will be the same as outputs[-1]. Otherwise it will be concatenated
+ # from different pieces of 'outputs', taking each dimension from the
+ # most recent output that has it present.
+ x = self._get_full_dim_output(outputs)
+ x = self.downsample_output(x)
+ # class Downsample has this rounding behavior..
+ assert self.output_downsampling_factor == 2
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ lengths = (x_lens + 1) // 2
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ lengths = (x_lens + 1) // 2
+
+ return x, lengths, new_states
+
+ @torch.jit.export
+ def get_init_states(
+ self,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+ ) -> List[Tensor]:
+ """Get initial states.
+
+ A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ """
+ states = []
+ for i, module in enumerate(self.encoders):
+ num_layers = module.num_layers
+ embed_dim = self.encoder_dim[i]
+ ds = self.downsampling_factor[i]
+ num_heads = self.num_heads[i]
+ key_dim = self.query_head_dim[i] * num_heads
+ value_dim = self.value_head_dim[i] * num_heads
+ downsample_left = self.left_context_frames[0] // ds
+ nonlin_attn_head_dim = 3 * embed_dim // 4
+ conv_left_pad = self.cnn_module_kernel[i] // 2
+ for layer in range(num_layers):
+ cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
+ device
+ )
+ cached_nonlin_attn = torch.zeros(
+ 1, batch_size, downsample_left, nonlin_attn_head_dim
+ ).to(device)
+ cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
+ device
+ )
+ cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
+ device
+ )
+ cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
+ device
+ )
+ cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
+ device
+ )
+ states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+
+ return states
+
+
+def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
+ return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
+
+
+def _balancer_schedule(min_prob: float):
+ return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))
+
+
+class Zipformer2EncoderLayer(nn.Module):
+ """
+ Args:
+ embed_dim: the number of expected features in the input (required).
+ num_heads: the number of heads in the multi-head attention modules (required).
+ feedforward_dim: the dimension of the feedforward network model (required).
+ dropout: the dropout value (default=0.1).
+ cnn_module_kernel (int): Kernel size of convolution module.
+
+ Examples::
+ >>> encoder_layer = Zipformer2EncoderLayer(
+ ... embed_dim=512, pos_dim=192, num_heads=8, query_head_dim=24,
+ ... pos_head_dim=4, value_head_dim=12, feedforward_dim=2048)
+ >>> src = torch.rand(10, 32, 512)
+ >>> pos_emb = torch.rand(1, 19, 192)
+ >>> out = encoder_layer(src, pos_emb)
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ value_head_dim: int,
+ feedforward_dim: int,
+ dropout: FloatLike = 0.1,
+ cnn_module_kernel: int = 31,
+ causal: bool = False,
+ attention_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ conv_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ const_attention_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.25), (4000.0, 0.025), default=0
+ ),
+ ff2_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ ff3_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ bypass_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.5), (4000.0, 0.02), default=0
+ ),
+ ) -> None:
+ super(Zipformer2EncoderLayer, self).__init__()
+ self.embed_dim = embed_dim
+
+ # self.bypass implements layer skipping as well as bypass; see its default values.
+ self.bypass = BypassModule(
+ embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
+ )
+ # bypass_mid is bypass used in the middle of the layer.
+ self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
+
+ # skip probability for dynamic modules (meaning: anything but feedforward).
+ self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
+ # an additional skip probability that applies to ConvModule to stop it from
+ # contributing too much early on.
+ self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+
+ # ff2_skip_rate is to prevent the ff2 module from having output that's too big
+ # compared to its residual.
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
+ self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
+
+ self.const_attention_rate = copy.deepcopy(const_attention_rate)
+
+ self.self_attn_weights = RelPositionMultiheadAttentionWeights(
+ embed_dim,
+ pos_dim=pos_dim,
+ num_heads=num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ dropout=0.0,
+ )
+
+ self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
+
+ self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
+
+ self.feed_forward1 = FeedforwardModule(
+ embed_dim, (feedforward_dim * 3) // 4, dropout
+ )
+
+ self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
+
+ self.feed_forward3 = FeedforwardModule(
+ embed_dim, (feedforward_dim * 5) // 4, dropout
+ )
+
+ self.nonlin_attention = NonlinAttention(
+ embed_dim, hidden_channels=3 * embed_dim // 4
+ )
+
+ self.conv_module1 = ConvolutionModule(
+ embed_dim, cnn_module_kernel, causal=causal
+ )
+
+ self.conv_module2 = ConvolutionModule(
+ embed_dim, cnn_module_kernel, causal=causal
+ )
+
+ # TODO: remove it
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+
+ self.norm = BiasNorm(embed_dim)
+
+ self.balancer1 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.2,
+ max_abs=4.0,
+ )
+
+ # balancer for output of NonlinAttentionModule
+ self.balancer_na = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+ prob=0.05, # out of concern for memory usage
+ )
+
+ # balancer for output of feedforward2, to prevent it from staying too
+ # small. Give this a very small probability, even at the start of
+ # training; it's to fix a rare problem and it's OK to fix it slowly.
+ self.balancer_ff2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
+ max_abs=2.0,
+ prob=0.05,
+ )
+
+ self.balancer_ff3 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
+ max_abs=4.0,
+ prob=0.05,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(4.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.balancer2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.1,
+ max_abs=4.0,
+ )
+
+ def get_sequence_dropout_mask(
+ self, x: Tensor, dropout_rate: float
+ ) -> Optional[Tensor]:
+ if (
+ dropout_rate == 0.0
+ or not self.training
+ or torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ ):
+ return None
+ batch_size = x.shape[1]
+ mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
+ return mask
+
+ def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
+ """
+ Apply sequence-level dropout to x.
+ x shape: (seq_len, batch_size, embed_dim)
+ """
+ dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
+ if dropout_mask is None:
+ return x
+ else:
+ return x * dropout_mask
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ chunk_size: int = -1,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Pass the input through the encoder layer.
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+ chunk_size: the number of frames per chunk; must be > 0, or -1 for no chunking.
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns:
+ A tensor which has the same shape as src
+ """
+ src_orig = src
+
+ # dropout rate for non-feedforward submodules
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ attention_skip_rate = 0.0
+ else:
+ attention_skip_rate = (
+ float(self.attention_skip_rate) if self.training else 0.0
+ )
+
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ attn_weights = self.self_attn_weights(
+ src,
+ pos_emb=pos_emb,
+ attn_mask=attn_mask,
+ key_padding_mask=src_key_padding_mask,
+ )
+
+ src = src + self.feed_forward1(src)
+
+ self_attn_dropout_mask = self.get_sequence_dropout_mask(
+ src, attention_skip_rate
+ )
+
+ selected_attn_weights = attn_weights[0:1]
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif not self.training and random.random() < float(self.const_attention_rate):
+ # Make attention weights constant. The intention is to
+ # encourage these modules to do something similar to an
+ # averaging-over-time operation.
+ # only need the mask, can just use the 1st one and expand later
+ selected_attn_weights = selected_attn_weights[0:1]
+ selected_attn_weights = (selected_attn_weights > 0.0).to(
+ selected_attn_weights.dtype
+ )
+ selected_attn_weights = selected_attn_weights * (
+ 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
+ )
+
+ na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+
+ src = src + (
+ na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
+ )
+
+ self_attn = self.self_attn1(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.conv_module1(
+ src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff2_skip_rate = 0.0
+ else:
+ ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
+ )
+
+ # bypass in the middle of the layer.
+ src = self.bypass_mid(src_orig, src)
+
+ self_attn = self.self_attn2(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.conv_module2(
+ src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff3_skip_rate = 0.0
+ else:
+ ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
+ )
+
+ src = self.balancer1(src)
+ src = self.norm(src)
+
+ src = self.bypass(src_orig, src)
+
+ src = self.balancer2(src)
+ src = self.whiten(src)
+
+ return src
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ cached_key: Tensor,
+ cached_nonlin_attn: Tensor,
+ cached_val1: Tensor,
+ cached_val2: Tensor,
+ cached_conv1: Tensor,
+ cached_conv2: Tensor,
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """Pass the input through the encoder layer in streaming forward mode.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
+ (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
+ cached_key: cached attention key tensor of left context,
+ of shape (left_context_len, batch_size, key_dim)
+ cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
+ (num_heads, batch_size, left_context_len, head_dim)
+ cached_val1: cached left context for the first attention module,
+ of shape (left_context_len, batch_size, value_dim)
+ cached_val2: cached left context for the second attention module,
+ of shape (left_context_len, batch_size, value_dim)
+ cached_conv1: cached left context for the first convolution module,
+ of shape (batch_size, channels, left_pad)
+ cached_conv2: cached left context for the second convolution module,
+ of shape (batch_size, channels, left_pad)
+ left_context_len: number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape
+ (batch_size, left_context_len + seq_len); True means masked position.
+ May be None.
+
+ Returns:
+ - x, with the same shape as src
+ - updated cached_key
+ - updated cached_nonlin_attn
+ - updated cached_val1
+ - updated cached_val2
+ - updated cached_conv1
+ - updated cached_conv2
+ """
+ src_orig = src
+
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ attn_weights, cached_key = self.self_attn_weights.streaming_forward(
+ src,
+ pos_emb=pos_emb,
+ cached_key=cached_key,
+ left_context_len=left_context_len,
+ key_padding_mask=src_key_padding_mask,
+ )
+
+ src = src + self.feed_forward1(src)
+
+ na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
+ src,
+ attn_weights[0:1],
+ cached_x=cached_nonlin_attn,
+ left_context_len=left_context_len,
+ )
+ src = src + na
+
+ self_attn, cached_val1 = self.self_attn1.streaming_forward(
+ src,
+ attn_weights=attn_weights,
+ cached_val=cached_val1,
+ left_context_len=left_context_len,
+ )
+ src = src + self_attn
+
+ src_conv, cached_conv1 = self.conv_module1.streaming_forward(
+ src,
+ cache=cached_conv1,
+ src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
+ )
+ src = src + src_conv
+
+ src = src + self.feed_forward2(src)
+
+ # bypass in the middle of the layer.
+ src = self.bypass_mid(src_orig, src)
+
+ self_attn, cached_val2 = self.self_attn2.streaming_forward(
+ src,
+ attn_weights=attn_weights,
+ cached_val=cached_val2,
+ left_context_len=left_context_len,
+ )
+ src = src + self_attn
+
+ src_conv, cached_conv2 = self.conv_module2.streaming_forward(
+ src,
+ cache=cached_conv2,
+ src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
+ )
+ src = src + src_conv
+
+ src = src + self.feed_forward3(src)
+
+ src = self.norm(src)
+
+ src = self.bypass(src_orig, src)
+
+ return (
+ src,
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ )
+
+
+class Zipformer2Encoder(nn.Module):
+ r"""Zipformer2Encoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+ pos_dim: the dimension for the relative positional encoding
+
+ Examples::
+ >>> encoder_layer = Zipformer2EncoderLayer(
+ ... embed_dim=512, pos_dim=192, num_heads=8, query_head_dim=24,
+ ... pos_head_dim=4, value_head_dim=12, feedforward_dim=2048)
+ >>> zipformer_encoder = Zipformer2Encoder(
+ ... encoder_layer, num_layers=6, pos_dim=192, dropout=0.1,
+ ... warmup_begin=0.0, warmup_end=4000.0)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = zipformer_encoder(src)
+ """
+
+ def __init__(
+ self,
+ encoder_layer: nn.Module,
+ num_layers: int,
+ pos_dim: int,
+ dropout: float,
+ warmup_begin: float,
+ warmup_end: float,
+ initial_layerdrop_rate: float = 0.5,
+ final_layerdrop_rate: float = 0.05,
+ ) -> None:
+ super().__init__()
+ self.encoder_pos = CompactRelPositionalEncoding(
+ pos_dim, dropout_rate=0.15, length_factor=1.0
+ )
+
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ assert 0 <= warmup_begin <= warmup_end
+
+ delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
+ cur_begin = warmup_begin # interpreted as a training batch index
+ for i in range(num_layers):
+ cur_end = cur_begin + delta
+ self.layers[i].bypass.skip_rate = ScheduledFloat(
+ (cur_begin, initial_layerdrop_rate),
+ (cur_end, final_layerdrop_rate),
+ default=0.0,
+ )
+ cur_begin = cur_end
+
+ def forward(
+ self,
+ src: Tensor,
+ chunk_size: int = -1,
+ feature_mask: Union[Tensor, float] = 1.0,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ chunk_size: the number of frames per chunk; must be > 0, or -1 for no chunking.
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
+ by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ pos_emb = self.encoder_pos(src)
+ output = src
+
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ output = output * feature_mask
+
+ for i, mod in enumerate(self.layers):
+ output = mod(
+ output,
+ pos_emb,
+ chunk_size=chunk_size,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ output = output * feature_mask
+
+ return output
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ states: List[Tensor],
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, List[Tensor]]:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
+ (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ left_context_len: Number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape
+ (batch_size, left_context_len + seq_len); True means masked position.
+ May be None.
+
+ Returns:
+ - output, a Tensor with the same shape as src.
+ - updated states
+ """
+ pos_emb = self.encoder_pos(src, left_context_len)
+ output = src
+
+ new_states = []
+ for i, mod in enumerate(self.layers):
+ (
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ) = states[i * 6 : (i + 1) * 6]
+ (
+ output,
+ new_cached_key,
+ new_cached_nonlin_attn,
+ new_cached_val1,
+ new_cached_val2,
+ new_cached_conv1,
+ new_cached_conv2,
+ ) = mod.streaming_forward(
+ output,
+ pos_emb,
+ cached_key=cached_key,
+ cached_nonlin_attn=cached_nonlin_attn,
+ cached_val1=cached_val1,
+ cached_val2=cached_val2,
+ cached_conv1=cached_conv1,
+ cached_conv2=cached_conv2,
+ left_context_len=left_context_len,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ new_states += [
+ new_cached_key,
+ new_cached_nonlin_attn,
+ new_cached_val1,
+ new_cached_val2,
+ new_cached_conv1,
+ new_cached_conv2,
+ ]
+
+ return output, new_states
+
+
+class BypassModule(nn.Module):
+ """
+ An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
+ layer-skipping. The bypass is limited during early stages of training to be close to
+ "straight-through", i.e. to not do the bypass operation much initially, in order to
+ force all the modules to learn something.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ skip_rate: FloatLike = 0.0,
+ straight_through_rate: FloatLike = 0.0,
+ scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
+ scale_max: FloatLike = 1.0,
+ ):
+ super().__init__()
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+ self.skip_rate = copy.deepcopy(skip_rate)
+ self.straight_through_rate = copy.deepcopy(straight_through_rate)
+ self.scale_min = copy.deepcopy(scale_min)
+ self.scale_max = copy.deepcopy(scale_max)
+
+ def _get_bypass_scale(self, batch_size: int):
+ # returns bypass-scale of shape (num_channels,),
+ # or (batch_size, num_channels,). This is actually the
+ # scale on the non-residual term, so 0 corresponds to bypassing
+ # this module.
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+ return self.bypass_scale
+ else:
+ ans = limit_param_value(
+ self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
+ )
+ skip_rate = float(self.skip_rate)
+ if skip_rate != 0.0:
+ mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
+ ans = ans * mask
+ # now ans is of shape (batch_size, num_channels), and is zero for sequences
+ # on which we have randomly chosen to do layer-skipping.
+ straight_through_rate = float(self.straight_through_rate)
+ if straight_through_rate != 0.0:
+ mask = (
+ torch.rand((batch_size, 1), device=ans.device)
+ < straight_through_rate
+ )
+ ans = torch.maximum(ans, mask.to(ans.dtype))
+ return ans
+
+ def forward(self, src_orig: Tensor, src: Tensor):
+ """
+ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
+ Returns: something with the same shape as src and src_orig
+ """
+ bypass_scale = self._get_bypass_scale(src.shape[1])
+ return src_orig + (src - src_orig) * bypass_scale
+
+
+class DownsampledZipformer2Encoder(nn.Module):
+ r"""
+ DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
+ after convolutional downsampling, and then upsampled again at the output, and combined
+ with the original input, so that the output has the same shape as the input.
+ """
+
+ def __init__(
+ self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
+ ):
+ super(DownsampledZipformer2Encoder, self).__init__()
+ self.downsample_factor = downsample
+ self.downsample = SimpleDownsample(dim, downsample, dropout)
+ self.num_layers = encoder.num_layers
+ self.encoder = encoder
+ self.upsample = SimpleUpsample(dim, downsample)
+ self.out_combiner = BypassModule(dim, straight_through_rate=0)
+
+ def forward(
+ self,
+ src: Tensor,
+ chunk_size: int = -1,
+ feature_mask: Union[Tensor, float] = 1.0,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Downsample, go through encoder, upsample.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
+ by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ src_orig = src
+ src = self.downsample(src)
+ ds = self.downsample_factor
+ if attn_mask is not None:
+ attn_mask = attn_mask[::ds, ::ds]
+
+ src = self.encoder(
+ src,
+ chunk_size=chunk_size // ds,
+ feature_mask=feature_mask,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ src = self.upsample(src)
+ # remove any extra frames that are not a multiple of downsample_factor
+ src = src[: src_orig.shape[0]]
+
+ return self.out_combiner(src_orig, src)
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ states: List[Tensor],
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, List[Tensor]]:
+ r"""Downsample, go through encoder, upsample, in streaming forward mode.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
+ (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ left_context_len: Number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len);
+ True means masked position. May be None.
+
+ Returns:
+ - output, a Tensor with the same shape as src.
+ - updated states
+ """
+ src_orig = src
+ src = self.downsample(src)
+
+ src, new_states = self.encoder.streaming_forward(
+ src,
+ states=states,
+ left_context_len=left_context_len,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ src = self.upsample(src)
+ # remove any extra frames that are not a multiple of downsample_factor
+ src = src[: src_orig.shape[0]]
+
+ return self.out_combiner(src_orig, src), new_states
+
+
+class SimpleDownsample(torch.nn.Module):
+ """
+ Does temporal downsampling by taking a learned softmax-weighted sum over each
+ group of `downsample` consecutive frames.
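+
+ Example (a hedged sketch; shapes chosen for illustration)::
+
+ >>> ds = SimpleDownsample(channels=384, downsample=2, dropout=0.1)
+ >>> x = torch.rand(100, 8, 384) # (seq_len, batch, channels)
+ >>> ds(x).shape
+ torch.Size([50, 8, 384])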
+ """
+
+ def __init__(self, channels: int, downsample: int, dropout: FloatLike):
+ super(SimpleDownsample, self).__init__()
+
+ self.bias = nn.Parameter(torch.zeros(downsample))
+
+ self.name = None # will be set from training code
+ self.dropout = copy.deepcopy(dropout)
+
+ self.downsample = downsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ x: (seq_len, batch_size, in_channels)
+ Returns a tensor of shape
+ ( (seq_len+downsample-1)//downsample, batch_size, channels)
+ """
+ (seq_len, batch_size, in_channels) = src.shape
+ ds = self.downsample
+ d_seq_len = (seq_len + ds - 1) // ds
+
+ # Pad to an exact multiple of self.downsample
+ # right-pad src, repeating the last element.
+ pad = d_seq_len * ds - seq_len
+ src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+ src = torch.cat((src, src_extra), dim=0)
+ assert src.shape[0] == d_seq_len * ds
+
+ src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+
+ weights = self.bias.softmax(dim=0)
+ # weights: (downsample, 1, 1)
+ weights = weights.unsqueeze(-1).unsqueeze(-1)
+
+ # ans: (d_seq_len, batch_size, in_channels), a weighted sum over each group of ds frames
+ ans = (src * weights).sum(dim=1)
+
+ return ans
+
+
+class SimpleUpsample(torch.nn.Module):
+ """
+ A very simple form of upsampling that just repeats each input frame `upsample`
+ times.
+ """
+
+ def __init__(self, num_channels: int, upsample: int):
+ super(SimpleUpsample, self).__init__()
+ self.upsample = upsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ x: (seq_len, batch_size, num_channels)
+ Returns a tensor of shape
+ ( (seq_len*upsample), batch_size, num_channels)
+ """
+ upsample = self.upsample
+ (seq_len, batch_size, num_channels) = src.shape
+ src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+ src = src.reshape(seq_len * upsample, batch_size, num_channels)
+ return src
+
+
+class CompactRelPositionalEncoding(torch.nn.Module):
+ """
+ Relative positional encoding module. This version is "compact" meaning it is able to encode
+ the important information about the relative position in a relatively small number of dimensions.
+ The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
+ make very little difference to the embedding. Such differences were potentially important
+ when encoding absolute position, but not important when encoding relative position because there
+ is now no need to compare two large offsets with each other.
+
+ Our embedding works by projecting the interval [-infinity,infinity] onto a finite interval
+ using the atan() function, before doing the Fourier transform of that fixed interval. The
+ atan() function on its own would compress the "long tails" into too small a range,
+ making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
+ function to compress large offsets to a smaller range before applying atan().
+ Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
+ as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim).
+
+
+ Args:
+ embed_dim: Embedding dimension.
+ dropout_rate: Dropout rate.
+ max_len: Maximum input length: just a heuristic for initialization.
+ length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+ less weight to small differences of offset near the origin.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ dropout_rate: FloatLike,
+ max_len: int = 1000,
+ length_factor: float = 1.0,
+ ) -> None:
+ """Construct a CompactRelPositionalEncoding object."""
+ super(CompactRelPositionalEncoding, self).__init__()
+ self.embed_dim = embed_dim
+ assert embed_dim % 2 == 0
+ self.dropout = Dropout2(dropout_rate)
+ self.pe = None
+ assert length_factor >= 1.0
+ self.length_factor = length_factor
+ self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+ """Reset the positional encodings."""
+ T = x.size(0) + left_context_len
+
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(0) >= T * 2 - 1:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+
+ # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
+ x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
+
+ freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+
+ # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution
+ # for small time offsets but less resolution for large time offsets.
+ compression_length = self.embed_dim**0.5
+ # x_compressed, like x, goes from -infinity to infinity as x goes from -infinity to infinity;
+ # but it does so more slowly than x for large absolute values of x.
+ # The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which
+ # is important.
+ x_compressed = (
+ compression_length
+ * x.sign()
+ * ((x.abs() + compression_length).log() - math.log(compression_length))
+ )
+
+ # if self.length_factor == 1.0, then length_scale is chosen so that the
+ # FFT can exactly separate points close to the origin (T == 0). So this
+ # part of the formulation is not really heuristic.
+ # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+ length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
+
+ # note for machine implementations: if atan is not available, we can use:
+ # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
+ # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
+ x_atan = (x_compressed / length_scale).atan() # results between -pi and pi
+
+ cosines = (x_atan * freqs).cos()
+ sines = (x_atan * freqs).sin()
+
+ pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+ pe[:, 0::2] = cosines
+ pe[:, 1::2] = sines
+ pe[:, -1] = 1.0 # for bias.
+
+ self.pe = pe.to(dtype=x.dtype)
+
+ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+ """Create positional encoding.
+
+ Args:
+ x (Tensor): Input tensor (time, batch, `*`).
+ left_context_len: (int): Length of cached left context.
+
+ Returns:
+ positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
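+
+ Example (a hedged sketch; shapes chosen for illustration)::
+
+ >>> pos_enc = CompactRelPositionalEncoding(192, dropout_rate=0.15)
+ >>> x = torch.zeros(100, 8, 192) # (time, batch, embed_dim)
+ >>> pos_enc(x).shape # 2*time-1 positions
+ torch.Size([1, 199, 192])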
+ """
+ self.extend_pe(x, left_context_len)
+ x_size_left = x.size(0) + left_context_len
+ # length of positive side: x.size(0) + left_context_len
+ # length of negative side: x.size(0)
+ pos_emb = self.pe[
+ self.pe.size(0) // 2
+ - x_size_left
+ + 1 : self.pe.size(0) // 2 # noqa E203
+ + x.size(0),
+ :,
+ ]
+ pos_emb = pos_emb.unsqueeze(0)
+ return self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttentionWeights(nn.Module):
+ r"""Module that computes multi-head attention weights with relative position encoding.
+ Various other modules consume the resulting attention weights: see, for example, the
+ SelfAttention module, which allows you to compute conventional attention.
+
+ This is quite heavily modified from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context";
+ we still have to write up the differences.
+
+
+ Args:
+ embed_dim: number of channels at the input to this module, e.g. 256
+ pos_dim: dimension of the positional encoding vectors, e.g. 128.
+ num_heads: number of heads to compute weights for, e.g. 8
+ query_head_dim: dimension of the query (and key), per head. e.g. 24.
+ pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
+ dropout: dropout probability for attn_output_weights. Default: 0.0.
+ pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
+ any given call to forward(), in training time.
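+
+ Example (a hedged sketch; shapes chosen for illustration)::
+
+ >>> attn_weights_module = RelPositionMultiheadAttentionWeights(
+ ... embed_dim=512, pos_dim=192, num_heads=8, query_head_dim=24,
+ ... pos_head_dim=4)
+ >>> x = torch.rand(10, 32, 512)
+ >>> pos_emb = torch.rand(1, 19, 192)
+ >>> attn_weights_module(x, pos_emb).shape
+ torch.Size([8, 32, 10, 10])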
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ dropout: float = 0.0,
+ pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
+ ) -> None:
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.query_head_dim = query_head_dim
+ self.pos_head_dim = pos_head_dim
+ self.dropout = dropout
+ self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
+ self.name = None # will be overwritten in training code; for diagnostics.
+
+ key_head_dim = query_head_dim
+ in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
+
+ # the initial_scale is supposed to take over the "scaling" factor of
+ # head_dim ** -0.5 that has been used in previous forms of attention,
+ # dividing it between the query and key. Note: this module is intended
+ # to be used with the ScaledAdam optimizer; with most other optimizers,
+ # it would be necessary to apply the scaling factor in the forward function.
+ self.in_proj = ScaledLinear(
+ embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
+ )
+
+ self.whiten_keys = Whiten(
+ num_groups=num_heads,
+ whitening_limit=_whitening_schedule(3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.025,
+ )
+
+ # add a balancer for the keys that runs with very small probability, and
+ # tries to enforce that all dimensions have mean around zero. The
+ # weights produced by this module are invariant to adding a constant to
+ # the keys, so the derivative of the bias is mathematically zero; but
+ # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
+ # bias because the small numerical roundoff tends to have a non-random
+ # sign. This module is intended to prevent that. Use a very small
+        # probability; that should be sufficient to fix the problem.
+ self.balance_keys = Balancer(
+ key_head_dim * num_heads,
+ channel_dim=-1,
+ min_positive=0.4,
+ max_positive=0.6,
+ min_abs=0.0,
+ max_abs=100.0,
+ prob=0.025,
+ )
+
+ # linear transformation for positional encoding.
+ self.linear_pos = ScaledLinear(
+ pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
+ )
+
+        # the following are for diagnostics only, see --print-diagnostics option
+ self.copy_pos_query = Identity()
+ self.copy_query = Identity()
+
+ def forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""
+ Args:
+ x: input of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
+ are True in this mask will be ignored as sources in the attention weighting.
+ attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
+ interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
+ saying which positions are allowed to attend to which other positions.
+ Returns:
+            a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
+            interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+ """
+ x = self.in_proj(x)
+ query_head_dim = self.query_head_dim
+ pos_head_dim = self.pos_head_dim
+ num_heads = self.num_heads
+
+ seq_len, batch_size, _ = x.shape
+
+ query_dim = query_head_dim * num_heads
+
+ # self-attention
+ q = x[..., 0:query_dim]
+ k = x[..., query_dim : 2 * query_dim]
+ # p is the position-encoding query
+ p = x[..., 2 * query_dim :]
+ assert p.shape[-1] == num_heads * pos_head_dim
+
+ q = self.copy_query(q) # for diagnostics only, does nothing.
+ k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
+ p = self.copy_pos_query(p) # for diagnostics only, does nothing.
+
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+ k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
+
+ # time1 refers to target, time2 refers to source.
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+
+ attn_scores = torch.matmul(q, k)
+
+ use_pos_scores = False
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+            # We can't put random.random() in the same line as the scripting check,
+            # so when scripting or tracing we always use the positional scores.
+ use_pos_scores = True
+ elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+ use_pos_scores = True
+
+ if use_pos_scores:
+ pos_emb = self.linear_pos(pos_emb)
+ seq_len2 = 2 * seq_len - 1
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+ 2, 0, 3, 1
+ )
+ # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+
+ # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+ # [where seq_len2 represents relative position.]
+ pos_scores = torch.matmul(p, pos_emb)
+ # the following .as_strided() expression converts the last axis of pos_scores from relative
+ # to absolute position. I don't know whether I might have got the time-offsets backwards or
+ # not, but let this code define which way round it is supposed to be.
+            if torch.jit.is_tracing():
+                (num_heads, batch_size, time1, n) = pos_scores.shape
+                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+                cols = torch.arange(seq_len)
+                rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+                indexes = rows + cols
+                pos_scores = pos_scores.reshape(-1, n)
+                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+                pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
+            else:
+                pos_scores = pos_scores.as_strided(
+                    (num_heads, batch_size, seq_len, seq_len),
+                    (
+                        pos_scores.stride(0),
+                        pos_scores.stride(1),
+                        pos_scores.stride(2) - pos_scores.stride(3),
+                        pos_scores.stride(3),
+                    ),
+                    storage_offset=pos_scores.stride(3) * (seq_len - 1),
+                )
+
+            attn_scores = attn_scores + pos_scores
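+            # Sketch of what the .as_strided() call above does, added for clarity
+            # (not part of the original comments): the strided view maps element
+            # [h, b, t1, t2] to the original pos_scores[h, b, t1, (seq_len - 1) + t2 - t1],
+            # so column t2 of row t1 picks up the score for relative offset t2 - t1,
+            # under the sign convention that this code defines.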
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif self.training and random.random() < 0.1:
+ # This is a harder way of limiting the attention scores to not be
+ # too large. It incurs a penalty if any of them has an absolute
+            # value greater than 25.0. This should be outside the normal range
+ # of the attention scores. We use this mechanism instead of, say,
+ # something added to the loss function involving the entropy,
+            # because once the entropy gets very small, gradients through the
+            # softmax can become very small, and we'd get zero derivatives. The
+            # choice of 1.0e-04 as the scale on the penalty makes this
+ # mechanism vulnerable to the absolute scale of the loss function,
+ # but we view this as a failsafe to avoid "implausible" parameter
+ # values rather than a regularization method that should be active
+ # under normal circumstances.
+ attn_scores = penalize_abs_values_gt(
+ attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
+ )
+
+ assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ if attn_mask is not None:
+ assert attn_mask.dtype == torch.bool
+            # Use -1000 to avoid nan's where attn_mask and key_padding_mask together
+            # mask out every position. It's important that -1000 be large enough in
+            # magnitude that exp(-1000) is exactly zero; this matters for
+            # const_attention_rate, which compares the final weights with zero.
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (
+ batch_size,
+ seq_len,
+ ), key_padding_mask.shape
+ attn_scores = attn_scores.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ -1000,
+ )
+
+        # We use our own version of softmax, defined in scaling.py, which should
+        # save a little of the memory used in backprop: if we are in
+        # automatic mixed precision mode (amp / autocast), it only stores the
+        # half-precision output for backprop purposes.
+ attn_weights = softmax(attn_scores, dim=-1)
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif random.random() < 0.001 and not self.training:
+ self._print_attn_entropy(attn_weights)
+
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+
+ return attn_weights
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ cached_key: Tensor,
+ left_context_len: int,
+ key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ r"""
+ Args:
+ x: input of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
+ cached_key: cached attention key tensor of left context,
+ of shape (left_context_len, batch_size, key_dim)
+ left_context_len: number of left context frames.
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
+ are True in this mask will be ignored as sources in the attention weighting.
+
+ Returns:
+            - attention weights, of shape (num_heads, batch_size, seq_len, seq_len2),
+              interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+ - updated cached attention key tensor of left context.
+ """
+ x = self.in_proj(x)
+ query_head_dim = self.query_head_dim
+ pos_head_dim = self.pos_head_dim
+ num_heads = self.num_heads
+
+ seq_len, batch_size, _ = x.shape
+
+ query_dim = query_head_dim * num_heads
+
+ # self-attention
+ q = x[..., 0:query_dim]
+ k = x[..., query_dim : 2 * query_dim]
+ # p is the position-encoding query
+ p = x[..., 2 * query_dim :]
+ assert p.shape[-1] == num_heads * pos_head_dim
+
+ # Pad cached left contexts
+ assert cached_key.shape[0] == left_context_len, (
+ cached_key.shape[0],
+ left_context_len,
+ )
+ k = torch.cat([cached_key, k], dim=0)
+ # Update cached left contexts
+ cached_key = k[-left_context_len:, ...]
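+        # Illustrative numbers (added for clarity): with left_context_len=64 and a
+        # 32-frame chunk, k now holds 96 frames and the refreshed cache keeps the
+        # newest 64 of them (the last 32 previously cached frames plus this chunk).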
+
+ # The length of key
+ k_len = k.shape[0]
+
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+ k = k.reshape(k_len, batch_size, num_heads, query_head_dim)
+
+ # time1 refers to target, time2 refers to source.
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+
+ attn_scores = torch.matmul(q, k)
+
+ pos_emb = self.linear_pos(pos_emb)
+ seq_len2 = 2 * seq_len - 1 + left_context_len
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+ 2, 0, 3, 1
+ )
+ # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+
+ # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+ # [where seq_len2 represents relative position.]
+ pos_scores = torch.matmul(p, pos_emb)
+
+ if torch.jit.is_tracing():
+ (num_heads, batch_size, time1, n) = pos_scores.shape
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+ cols = torch.arange(k_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+ pos_scores = pos_scores.reshape(-1, n)
+ pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+ pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
+        else:
+            # the following .as_strided() expression converts the last axis of pos_scores from relative
+            # to absolute position. I don't know whether I might have got the time-offsets backwards or
+            # not, but let this code define which way round it is supposed to be.
+ pos_scores = pos_scores.as_strided(
+ (num_heads, batch_size, seq_len, k_len),
+ (
+ pos_scores.stride(0),
+ pos_scores.stride(1),
+ pos_scores.stride(2) - pos_scores.stride(3),
+ pos_scores.stride(3),
+ ),
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
+ )
+
+ attn_scores = attn_scores + pos_scores
+
+ assert attn_scores.shape == (
+ num_heads,
+ batch_size,
+ seq_len,
+ k_len,
+ ), attn_scores.shape
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
+ attn_scores = attn_scores.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ -1000,
+ )
+
+ attn_weights = attn_scores.softmax(dim=-1)
+
+ return attn_weights, cached_key
+
+ def _print_attn_entropy(self, attn_weights: Tensor):
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ attn_weights = attn_weights.to(torch.float32)
+ attn_weights_entropy = (
+ -((attn_weights + 1.0e-20).log() * attn_weights)
+ .sum(dim=-1)
+ .mean(dim=(1, 2))
+ )
+ logging.info(
+ f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
+ )
+
+
+class SelfAttention(nn.Module):
+ """
+ The simplest possible attention module. This one works with already-computed attention
+ weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
+
+ Args:
+ embed_dim: the input and output embedding dimension
+ num_heads: the number of attention heads
+ value_head_dim: the value dimension per head
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ value_head_dim: int,
+ ) -> None:
+ super().__init__()
+ self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
+
+ self.out_proj = ScaledLinear(
+ num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """
+ Args:
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+ with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
+ attn_weights.sum(dim=-1) == 1.
+ Returns:
+ a tensor with the same shape as x.
+ """
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
+ value_head_dim = x.shape[-1]
+
+ # todo: see whether there is benefit in overriding matmul
+ x = torch.matmul(attn_weights, x)
+ # v: (num_heads, batch_size, seq_len, value_head_dim)
+
+ x = (
+ x.permute(2, 1, 0, 3)
+ .contiguous()
+ .view(seq_len, batch_size, num_heads * value_head_dim)
+ )
+
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+ x = self.out_proj(x)
+ x = self.whiten(x)
+
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ cached_val: Tensor,
+ left_context_len: int,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+ with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
+ attn_weights.sum(dim=-1) == 1.
+ cached_val: cached attention value tensor of left context,
+ of shape (left_context_len, batch_size, value_dim)
+ left_context_len: number of left context frames.
+
+ Returns:
+ - attention weighted output, a tensor with the same shape as x.
+ - updated cached attention value tensor of left context.
+ """
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ seq_len2 = seq_len + left_context_len
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)
+
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
+
+ # Pad cached left contexts
+ assert cached_val.shape[0] == left_context_len, (
+ cached_val.shape[0],
+ left_context_len,
+ )
+ x = torch.cat([cached_val, x], dim=0)
+ # Update cached left contexts
+ cached_val = x[-left_context_len:, ...]
+
+ x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
+ value_head_dim = x.shape[-1]
+
+ # todo: see whether there is benefit in overriding matmul
+ x = torch.matmul(attn_weights, x)
+ # v: (num_heads, batch_size, seq_len, value_head_dim)
+
+ x = (
+ x.permute(2, 1, 0, 3)
+ .contiguous()
+ .view(seq_len, batch_size, num_heads * value_head_dim)
+ )
+
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+ x = self.out_proj(x)
+
+ return x, cached_val
+
+
+class FeedforwardModule(nn.Module):
+ """Feedforward module in Zipformer2 model."""
+
+ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
+ super(FeedforwardModule, self).__init__()
+ self.in_proj = nn.Linear(embed_dim, feedforward_dim)
+
+ self.hidden_balancer = Balancer(
+ feedforward_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=1.0,
+ min_abs=0.75,
+ max_abs=5.0,
+ )
+
+ # shared_dim=0 means we share the dropout mask along the time axis
+ self.out_proj = ActivationDropoutAndLinear(
+ feedforward_dim,
+ embed_dim,
+ activation="SwooshL",
+ dropout_p=dropout,
+ dropout_shared_dim=0,
+ bias=True,
+ initial_scale=0.1,
+ )
+
+ self.out_whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(self, x: Tensor):
+ x = self.in_proj(x)
+ x = self.hidden_balancer(x)
+ # out_proj contains SwooshL activation, then dropout, then linear.
+ x = self.out_proj(x)
+ x = self.out_whiten(x)
+ return x
+
+
+class NonlinAttention(nn.Module):
+ """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
+ from the attention module) in place of actual convolution. We also took out the second nonlinearity, the
+ one after the attention mechanism.
+
+ Args:
+ channels (int): The number of channels of conv layers.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ hidden_channels: int,
+ ) -> None:
+ super().__init__()
+
+ self.hidden_channels = hidden_channels
+
+ self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
+
+        # balancer that goes before the tanh; it constrains the abs-value of the gate input,
+        # because we noticed that well-trained instances of this module have abs-values before the nonlinearity
+        # starting from about 3, and poorly-trained instances of the module have smaller abs values
+        # before it.
+ self.balancer = Balancer(
+ hidden_channels,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
+ max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
+ min_abs=0.5,
+ max_abs=5.0,
+ )
+ self.tanh = nn.Tanh()
+
+ self.identity1 = Identity() # for diagnostics.
+ self.identity2 = Identity() # for diagnostics.
+ self.identity3 = Identity() # for diagnostics.
+
+ self.out_proj = ScaledLinear(
+ hidden_channels, channels, bias=True, initial_scale=0.05
+ )
+
+ self.whiten1 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.whiten2 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """.
+ Args:
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+ Returns:
+ a Tensor with the same shape as x
+ """
+ x = self.in_proj(x)
+
+ (seq_len, batch_size, _) = x.shape
+ hidden_channels = self.hidden_channels
+
+ s, x, y = x.chunk(3, dim=2)
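+        # Roles of the three chunks (comment added for clarity): s gates x (after
+        # the balancer and tanh below) before the attention mixing; y gates the
+        # attention output just before out_proj.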
+
+ # s will go through tanh.
+
+ s = self.balancer(s)
+ s = self.tanh(s)
+
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+ x = self.whiten1(x)
+ x = x * s
+ x = self.identity1(x) # diagnostics only, it's the identity.
+
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = torch.matmul(attn_weights, x)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+ y = self.identity2(y)
+ x = x * y
+ x = self.identity3(x)
+
+ x = self.out_proj(x)
+ x = self.whiten2(x)
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ cached_x: Tensor,
+ left_context_len: int,
+ ) -> Tuple[Tensor, Tensor]:
+ """.
+ Args:
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+ cached_x: left context, a Tensor of shape
+ (num_heads, batch_size, left_context_len, head_dim)
+ left_context_len: number of left context frames.
+ Returns:
+ - a Tensor with the same shape as x
+ - updated left context with same shape as cached_x
+ """
+ x = self.in_proj(x)
+
+ (seq_len, batch_size, _) = x.shape
+ hidden_channels = self.hidden_channels
+
+ s, x, y = x.chunk(3, dim=2)
+
+ # s will go through tanh.
+ s = self.tanh(s)
+
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+ x = x * s
+
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (
+ num_heads,
+ batch_size,
+ seq_len,
+ left_context_len + seq_len,
+ )
+
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+
+ # Pad cached tensor
+ assert cached_x.shape[2] == left_context_len, (
+ cached_x.shape[2],
+ left_context_len,
+ )
+ x_pad = torch.cat([cached_x, x], dim=2)
+ # Update cached tensor
+ cached_x = x_pad[:, :, -left_context_len:, :]
+
+ x = torch.matmul(attn_weights, x_pad)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+ x = x * y
+
+ x = self.out_proj(x)
+ return x, cached_x
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Zipformer2 model.
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
+
+ Args:
+ channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+ bias (bool): Whether to use bias in conv layers (default=True).
+
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ causal: bool,
+ ) -> None:
+ """Construct a ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ bottleneck_dim = channels
+ self.causal = causal
+
+ self.in_proj = nn.Linear(
+ channels,
+ 2 * bottleneck_dim,
+ )
+ # the gradients on in_proj are a little noisy, likely to do with the
+ # sigmoid in glu.
+
+ # after in_proj we put x through a gated linear unit (nn.functional.glu).
+ # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+ # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+ # between 50 and 100 for different channels. This will cause very peaky and
+ # sparse derivatives for the sigmoid gating function, which will tend to make
+ # the loss function not learn effectively. (for most layers the average absolute values
+ # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+ # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
+ # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid). The idea is that if we
+ # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+ # it will be in a better position to start learning something, i.e. to latch onto
+ # the correct range.
+ self.balancer1 = Balancer(
+ bottleneck_dim,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
+ max_positive=1.0,
+ min_abs=1.5,
+ max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
+ )
+
+ self.activation1 = Identity() # for diagnostics
+
+ self.sigmoid = nn.Sigmoid()
+
+ self.activation2 = Identity() # for diagnostics
+
+ assert kernel_size % 2 == 1
+
+ self.depthwise_conv = (
+ ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
+ if causal
+ else nn.Conv1d(
+ in_channels=bottleneck_dim,
+ out_channels=bottleneck_dim,
+ groups=bottleneck_dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
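+        # For example (illustrative): with kernel_size=31 the non-causal nn.Conv1d
+        # pads kernel_size // 2 = 15 frames on each side, so the output length
+        # equals the input length; the causal branch delegates its padding to
+        # ChunkCausalDepthwiseConv1d instead.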
+
+ self.balancer2 = Balancer(
+ bottleneck_dim,
+ channel_dim=1,
+ min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
+ max_positive=1.0,
+ min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
+ max_abs=10.0,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.out_proj = ActivationDropoutAndLinear(
+ bottleneck_dim,
+ channels,
+ activation="SwooshR",
+ dropout_p=0.0,
+ initial_scale=0.05,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ chunk_size: int = -1,
+ ) -> Tensor:
+ """Compute convolution module.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ src_key_padding_mask: the mask for the src keys per batch (optional):
+ (batch, #time), contains True in masked positions.
+
+ Returns:
+ Tensor: Output tensor (#time, batch, channels).
+
+ """
+
+ x = self.in_proj(x) # (time, batch, 2*channels)
+
+ x, s = x.chunk(2, dim=2)
+ s = self.balancer1(s)
+ s = self.sigmoid(s)
+ x = self.activation1(x) # identity.
+ x = x * s
+ x = self.activation2(x) # identity
+
+ # (time, batch, channels)
+
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ if src_key_padding_mask is not None:
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+ if (
+ not torch.jit.is_scripting()
+ and not torch.jit.is_tracing()
+ and chunk_size >= 0
+ ):
+            # Exporting a model for simulated streaming decoding is not supported
+ assert (
+ self.causal
+ ), "Must initialize model with causal=True if you use chunk_size"
+ x = self.depthwise_conv(x, chunk_size=chunk_size)
+ else:
+ x = self.depthwise_conv(x)
+
+ x = self.balancer2(x)
+ x = x.permute(2, 0, 1) # (time, batch, channels)
+
+ x = self.whiten(x) # (time, batch, channels)
+ x = self.out_proj(x) # (time, batch, channels)
+
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ cache: Tensor,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """Compute convolution module in streaming forward mode.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ cache: cached left context for depthwise_conv of shape
+ (#batch, channels, left_pad)
+ src_key_padding_mask: the mask for the src keys per batch (optional):
+ (batch, #time), contains True in masked positions.
+
+ Returns:
+ - Output tensor (#time, batch, channels).
+ - Updated cache (#batch, channels, left_pad)
+ """
+
+ x = self.in_proj(x) # (time, batch, 2*channels)
+
+ x, s = x.chunk(2, dim=2)
+ s = self.sigmoid(s)
+ x = x * s
+ # (time, batch, channels)
+
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ if src_key_padding_mask is not None:
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+ x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)
+
+ x = x.permute(2, 0, 1) # (time, batch, channels)
+
+ x = self.out_proj(x) # (time, batch, channels)
+
+ return x, cache
+
+
+class ScalarMultiply(nn.Module):
+ def __init__(self, scale: float):
+ super().__init__()
+ self.scale = scale
+
+ def forward(self, x):
+ return x * self.scale
+
+
+def _test_zipformer_main(causal: bool = False):
+ batch_size = 5
+ seq_len = 20
+ # Just make sure the forward pass runs.
+
+ c = Zipformer2(
+ encoder_dim=(64, 96),
+ encoder_unmasked_dim=(48, 64),
+ num_heads=(4, 4),
+ causal=causal,
+ chunk_size=(4,) if causal else (-1,),
+ left_context_frames=(64,),
+ )
+ f = c(
+ torch.randn(seq_len, batch_size, 64),
+ torch.full((batch_size,), seq_len, dtype=torch.int64),
+ )
+ f[0].sum().backward()
+ c.eval()
+ f = c(
+ torch.randn(seq_len, batch_size, 64),
+ torch.full((batch_size,), seq_len, dtype=torch.int64),
+ )
+ f # to remove flake8 warnings
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.INFO)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ _test_zipformer_main(False)
+ _test_zipformer_main(True)
diff --git a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
index 2f8e658c5..e1a29bd9c 100644
--- a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
@@ -227,6 +227,8 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py b/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py
new file mode 100755
index 000000000..b6b1cb020
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py
@@ -0,0 +1,592 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
+# Fangjun Kuang,
+# Quandong Wang)
+# 2023 Johns Hopkins University (Author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.decode import get_lattice, one_best_decoding
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ get_texts,
+ load_averaged_model,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--otc-token",
+ type=str,
+ default="",
+ help="OTC token",
+ )
+
+ parser.add_argument(
+ "--blank-bias",
+ type=float,
+ default=0,
+ help="bias (log-prob) added to blank token during decoding",
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=20,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=5,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="ctc-greedy-search",
+ help="""Decoding method.
+ Supported values are:
+ - (0) 1best. Extract the best path from the decoding lattice as the
+ decoding result.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--num-decoder-layers",
+ type=int,
+ default=0,
+ help="""Number of decoder layer of transformer decoder.
+ Setting this to 0 will not create the decoder at all (pure CTC model)
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="conformer_ctc2/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_phone",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--lm-dir",
+ type=str,
+ default="data/lm",
+ help="""The n-gram LM dir.
+ It should contain either G_4_gram.pt or G_4_gram.fst.txt
+ """,
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ # parameters for conformer
+ "subsampling_factor": 4,
+ "feature_dim": 80,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "encoder_dim": 512,
+ "num_encoder_layers": 12,
+ # parameters for decoding
+ "search_beam": 20,
+ "output_beam": 8,
+ "min_active_states": 30,
+ "max_active_states": 10000,
+ "use_double_scores": True,
+ "env_info": get_env_info(),
+ }
+ )
+ return params
+
+
+def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
+ # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
+ new_hyp: List[int] = []
+ cur = 0
+ while cur < len(hyp):
+ if hyp[cur] != 0:
+ new_hyp.append(hyp[cur])
+ prev = cur
+ while cur < len(hyp) and hyp[cur] == hyp[prev]:
+ cur += 1
+ return new_hyp
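+# For example (illustrative): remove_duplicates_and_blank([1, 1, 0, 2, 2, 0, 3])
+# returns [1, 2, 3]; repeated ids are collapsed and blank (id 0) is dropped.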
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ HLG: k2.Fsa,
+ batch: dict,
+ word_table: k2.SymbolTable,
+ G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if no rescoring is used, the key is the string `no_rescore`.
+ If LM rescoring is used, the key is the string `lm_scale_xxx`,
+ where `xxx` is the value of `lm_scale`. An example key is
+ `lm_scale_0.7`
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+
+ - params.method is "1best", it uses 1best decoding without LM rescoring.
+
+ model:
+ The neural model.
+    HLG:
+      The decoding graph.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ G:
+ An LM. It is not None when params.method is "nbest-rescoring"
+ or "whole-lattice-rescoring". In general, the G in HLG
+ is a 3-gram LM, while this G is a 4-gram LM.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict. Note: If it decodes to nothing, then return None.
+ """
+ device = HLG.device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+
+ nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+ # nnet_output is (N, T, C)
+ nnet_output[:, :, 0] += params.blank_bias
+
+ supervision_segments = torch.stack(
+ (
+ supervisions["sequence_idx"],
+ torch.div(
+ supervisions["start_frame"],
+ params.subsampling_factor,
+ rounding_mode="trunc",
+ ),
+ torch.div(
+ supervisions["num_frames"],
+ params.subsampling_factor,
+ rounding_mode="trunc",
+ ),
+ ),
+ 1,
+ ).to(torch.int32)
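+    # Illustrative example (not from the original code): with subsampling_factor=4,
+    # a supervision with start_frame=8 and num_frames=100 becomes the row
+    # [sequence_idx, 2, 25] in supervision_segments.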
+
+ decoding_graph = HLG
+
+ lattice = get_lattice(
+ nnet_output=nnet_output,
+ decoding_graph=decoding_graph,
+ supervision_segments=supervision_segments,
+ search_beam=params.search_beam,
+ output_beam=params.output_beam,
+ min_active_states=params.min_active_states,
+ max_active_states=params.max_active_states,
+ subsampling_factor=params.subsampling_factor + 2,
+ )
+
+ if params.method in ["1best"]:
+ best_path = one_best_decoding(
+ lattice=lattice, use_double_scores=params.use_double_scores
+ )
+ key = "no_rescore"
+
+ hyps = get_texts(best_path)
+ hyps = [[word_table[i] for i in ids] for ids in hyps]
+
+ return {key: hyps}
+ else:
+ assert False, f"Unsupported decoding method: {params.method}"
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ HLG: k2.Fsa,
+ word_table: k2.SymbolTable,
+ G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+    HLG:
+      The decoding graph.
+    word_table:
+      It is the word symbol table.
+ G:
+ An LM. It is not None when params.method is "nbest-rescoring"
+ or "whole-lattice-rescoring". In general, the G in HLG
+ is a 3-gram LM, while this G is a 4-gram LM.
+ Returns:
+ Return a dict, whose key may be "no-rescore" if no LM rescoring
+ is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the
+      predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ HLG=HLG,
+ batch=batch,
+ word_table=word_table,
+ G=G,
+ )
+
+ if hyps_dict is not None:
+ for lm_scale, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[lm_scale].extend(this_batch)
+ else:
+ assert len(results) > 0, "It should not decode to empty in the first batch!"
+            this_batch = []
+            hyp_words = []
+            for cut_id, ref_text in zip(cut_ids, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+ for lm_scale in results.keys():
+ results[lm_scale].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ if params.method in ("attention-decoder", "rnn-lm"):
+ # Set it to False since there are too many logs.
+ enable_log = False
+ else:
+ enable_log = True
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ if enable_log:
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=enable_log
+ )
+ test_set_wers[key] = wer
+
+ if enable_log:
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+ args.lm_dir = Path(args.lm_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+ logging.info("Decoding started")
+ logging.info(params)
+
+ lexicon = Lexicon(params.lang_dir)
+ # remove otc_token from decoding units
+ max_token_id = len(lexicon.tokens) - 1
+ num_classes = max_token_id + 1 # +1 for the blank
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ params.num_classes = num_classes
+
+ HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+ HLG = HLG.to(device)
+ assert HLG.requires_grad is False
+
+ if not hasattr(HLG, "lm_scores"):
+ HLG.lm_scores = HLG.scores.clone()
+
+ G = None
+
+ model = Conformer(
+ num_features=params.feature_dim,
+ nhead=params.nhead,
+ d_model=params.encoder_dim,
+ num_classes=num_classes,
+ subsampling_factor=params.subsampling_factor,
+ num_encoder_layers=params.num_encoder_layers,
+ num_decoder_layers=params.num_decoder_layers,
+ )
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
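+            # For example (illustrative): with --use-averaged-model 0, --epoch 20
+            # and --avg 5, this averages epoch-16.pt through epoch-20.pt.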
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ test_clean_cuts = librispeech.test_clean_cuts()
+ test_other_cuts = librispeech.test_other_cuts()
+
+ test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+ test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+ test_sets = ["test-clean", "test-other"]
+ test_dl = [test_clean_dl, test_other_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ HLG=HLG,
+ word_table=lexicon.word_table,
+ )
+
+ save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+ logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py
index fe6c5af91..82c68803f 100755
--- a/egs/librispeech/WSASR/conformer_ctc2/train.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/train.py
@@ -31,6 +31,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir conformer_ctc2/exp \
--lang-dir data/lang_bpe_200 \
--otc-token "" \
+ --feature-dim 768 \
--allow-bypass-arc true \
--allow-self-loop-arc true \
--initial-bypass-weight -19 \
@@ -160,6 +161,14 @@ def get_parser():
""",
)
+ parser.add_argument(
+ "--feature-dim",
+ type=int,
+ default=768,
+ help="""Number of features extracted in feature extraction stage.last dimension of feature vector.
+ 80 when using fbank features and 768 or 1024 whn using wave2vec""",
+ )
+
parser.add_argument(
"--initial-lr",
type=float,
@@ -385,7 +394,6 @@ def get_params() -> AttributeDict:
"valid_interval": 800, # For the 100h subset, use 800
"alignment_interval": 25,
# parameters for conformer
- "feature_dim": 768,
"subsampling_factor": 2,
"encoder_dim": 512,
"nhead": 8,
diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py
new file mode 100755
index 000000000..b276d0587
--- /dev/null
+++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py
@@ -0,0 +1,1124 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Quandong Wang)
+# 2023 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./conformer_ctc2/train_phone.py \
+ --world-size 4 \
+ --manifest-dir data/ssl \
+ --train-manifest librispeech_cuts_train-clean-100_0.17_0.17_0.17.jsonl.gz \
+ --exp-dir conformer_ctc2/exp \
+ --lang-dir data/lang_bpe_200 \
+ --otc-token "" \
+ --feature-dim 768 \
+ --allow-bypass-arc true \
+ --allow-self-loop-arc true \
+ --initial-bypass-weight -19 \
+ --initial-self-loop-weight 3.75 \
+ --bypass-weight-decay 0.975 \
+ --self-loop-weight-decay 0.999 \
+ --show-alignment true
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from optim import Eden, Eve
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.decode import one_best_decoding
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ encode_supervisions_otc,
+ get_texts,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=20,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="conformer_ctc2/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_bpe_200",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--feature-dim",
+ type=int,
+ default=80,
+ help="""Number of features extracted in feature extraction stage.last dimension of feature vector.
+ 80 when using fbank features and 768 or 1024 whn using wave2vec""",
+ )
+
+ parser.add_argument(
+ "--initial-lr",
+ type=float,
+ default=0.003,
+ help="""The initial learning rate. This value should not need to be
+ changed.""",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate decreases.
+ We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--att-rate",
+ type=float,
+ default=0.0,
+ help="""The attention rate.
+ The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
+ """,
+ )
+
+ parser.add_argument(
+ "--num-decoder-layers",
+ type=int,
+ default=0,
+ help="""Number of decoder layer of transformer decoder.
+ Setting this to 0 will not create the decoder at all (pure CTC model)
+ """,
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=8000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=10,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=100,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
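+    # For instance (illustrative): at batch_idx_train=1000 with the default
+    # average-period of 100, the update above weights the current model by 0.1
+    # and the running average by 0.9.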
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--otc-token",
+ type=str,
+ default="_",
+ help="OTC token",
+ )
+
+ parser.add_argument(
+ "--allow-bypass-arc",
+ type=str2bool,
+ default=True,
+ help="""Whether to add bypass arc to training graph for substitution
+ and insertion errors (wrong or extra words in the transcript).""",
+ )
+
+ parser.add_argument(
+ "--allow-self-loop-arc",
+ type=str2bool,
+ default=True,
+ help="""Whether to self-loop bypass arc to training graph for deletion errors
+ (missing words in the transcript).""",
+ )
+
+ parser.add_argument(
+ "--initial-bypass-weight",
+ type=float,
+ default=0.0,
+ help="Initial weight associated with bypass arc",
+ )
+
+ parser.add_argument(
+ "--initial-self-loop-weight",
+ type=float,
+ default=0.0,
+ help="Initial weight associated with self-loop arc",
+ )
+
+ parser.add_argument(
+ "--bypass-weight-decay",
+ type=float,
+ default=1.0,
+ help="""Weight decay factor of bypass arc weight:
+        bypass_arc_weight = initial_bypass_weight * bypass_weight_decay ^ ith-epoch""",
+ )
+
+ parser.add_argument(
+ "--self-loop-weight-decay",
+ type=float,
+ default=1.0,
+ help="""Weight decay factor of self-loop arc weight:
+        self_loop_arc_weight = initial_self_loop_weight * self_loop_weight_decay ^ ith-epoch""",
+ )
+
+ parser.add_argument(
+ "--show-alignment",
+ type=str2bool,
+ default=True,
+ help="Whether to print OTC alignment during training",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+      - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+      - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - beam_size: It is used in k2.ctc_loss
+
+ - reduction: It is used in k2.ctc_loss
+
+ - use_double_scores: It is used in k2.ctc_loss
+
+ - warm_step: The warm_step for Noam optimizer.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 1,
+ "reset_interval": 200,
+ "valid_interval": 800, # For the 100h subset, use 800
+ "alignment_interval": 100,
+ # parameters for conformer
+ "subsampling_factor": 4,
+ "encoder_dim": 512,
+ "nhead": 8,
+ "dim_feedforward": 2048,
+ "num_encoder_layers": 12,
+ # parameters for ctc loss
+ "beam_size": 10,
+ "reduction": "sum",
+ "use_double_scores": True,
+ # parameters for Noam
+ "model_warm_step": 3000, # arg given to model, not for lrate
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.exp_dir/epoch-{params.start_epoch-1}.pt`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ batch: dict,
+ graph_compiler: OtcPhoneTrainingGraphCompiler,
+ is_training: bool,
+ warmup: float = 2.0,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute OTC loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Conformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ graph_compiler:
+ It is used to build a decoding graph from a ctc topo and training
+ transcript. The training transcript is contained in the given `batch`,
+ while the ctc topo is built when this compiler is instantiated.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ with torch.set_grad_enabled(is_training):
+ nnet_output, encoder_memory, memory_mask = model(
+ feature, supervisions, warmup=warmup
+ )
+ # Set the probability of the OTC token to the average of the non-blank
+ # token probabilities, assuming blank is the first token and the OTC
+ # token is the last token in tokens.txt.
+ _, _, V = nnet_output.shape
+
+ otc_token_log_prob = torch.logsumexp(
+ nnet_output[:, :, 1:], dim=-1, keepdim=True
+ ) - torch.log(torch.tensor([V - 1])).to(device)
+
+ nnet_output = torch.cat([nnet_output, otc_token_log_prob], dim=-1)
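+ # logsumexp over the non-blank columns minus log(V - 1) equals the log of
+ # the arithmetic mean of the non-blank probabilities; after the
+ # concatenation this value occupies the last column (index V), i.e. the
+ # OTC token's position.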
+
+ # NOTE: We need `encode_supervisions` to sort sequences with
+ # different duration in decreasing order, required by
+ # `k2.intersect_dense` called in `k2.ctc_loss`
+ supervision_segments, texts, utt_ids, verbatim_texts = encode_supervisions_otc(
+ supervisions, subsampling_factor=params.subsampling_factor
+ )
+
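+ # The bypass and self-loop arc weights decay geometrically with the epoch
+ # index; e.g. with a (hypothetical) decay factor of 0.9, each new epoch
+ # multiplies the corresponding weight by 0.9.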
+ bypass_weight = graph_compiler.initial_bypass_weight * (
+ graph_compiler.bypass_weight_decay ** (params.cur_epoch - 1)
+ )
+ self_loop_weight = graph_compiler.initial_self_loop_weight * (
+ graph_compiler.self_loop_weight_decay ** (params.cur_epoch - 1)
+ )
+
+ decoding_graph = graph_compiler.compile(
+ texts=texts,
+ allow_bypass_arc=params.allow_bypass_arc,
+ allow_self_loop_arc=params.allow_self_loop_arc,
+ bypass_weight=bypass_weight,
+ self_loop_weight=self_loop_weight,
+ )
+
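+ # allow_truncate=3 tolerates supervision segments that are up to 3 frames
+ # longer than the network output (rounding effects of subsampling).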
+ dense_fsa_vec = k2.DenseFsaVec(
+ nnet_output,
+ supervision_segments,
+ allow_truncate=3,
+ )
+
+ otc_loss = k2.ctc_loss(
+ decoding_graph=decoding_graph,
+ dense_fsa_vec=dense_fsa_vec,
+ output_beam=params.beam_size,
+ reduction=params.reduction,
+ use_double_scores=params.use_double_scores,
+ )
+
+ assert params.att_rate == 0.0
+ loss = otc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+ info["otc_loss"] = otc_loss.detach().cpu().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+
+ # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
+ info["utterances"] = feature.size(0)
+ # averaged input duration in frames over utterances
+ info["utt_duration"] = feature_lens.sum().item()
+ # averaged padding proportion over utterances
+ info["utt_pad_proportion"] = (
+ ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+ )
+
+ if params.show_alignment and params.batch_idx_train % params.alignment_interval == 0:
+ # The lattice and its best path are the same for every utterance in the
+ # batch, so compute them once outside the per-utterance loop.
+ lattice = k2.intersect_dense(
+ decoding_graph,
+ dense_fsa_vec,
+ params.beam_size,
+ )
+ best_path = one_best_decoding(
+ lattice=lattice,
+ use_double_scores=params.use_double_scores,
+ )
+ all_hyp_ids = get_texts(best_path)
+
+ for index, utt_id in enumerate(utt_ids):
+ verbatim_text = verbatim_texts[index]
+ hyp_text = " ".join(
+ graph_compiler.word_table[i] for i in all_hyp_ids[index]
+ )
+
+ logging.info(f"[utterance id]: {utt_id}")
+ logging.info(f"[verbatim text]: {verbatim_text}")
+ logging.info(f"[best alignment]: {hyp_text}")
+
+ logging.info(f"[bypass weight]: {bypass_weight}")
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: OtcPhoneTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: OtcPhoneTrainingGraphCompiler,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The average per-frame training loss is saved in `params.train_loss`.
+ The validation process is run every `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ graph_compiler:
+ It is used to convert transcripts to FSAs.
+ scheduler:
+ The learning rate scheduler; step_batch() is called after every batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=True,
+ warmup=(params.batch_idx_train / params.model_warm_step),
+ )
+ # summary stats
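+ # tot_loss is an exponential moving average of the per-batch stats with a
+ # time constant of roughly reset_interval batches.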
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction="sum", so the loss is summed over the utterances
+ # in the batch; no normalization has been applied to it so far.
+
+ try:
+ scaler.scale(loss).backward()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(f"failing batch size:{batch_size} ")
+ raise
+
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+
+ if params.print_diagnostics and batch_idx == 30:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}"
+ )
+ if loss_info["otc_loss"] == float("inf"):
+ logging.error("Your loss contains inf, something goes wrong")
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+ params.valid_interval = 1600
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+ logging.info(params)
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = OtcPhoneTrainingGraphCompiler(
+ lexicon,
+ otc_token=params.otc_token,
+ device=device,
+ initial_bypass_weight=params.initial_bypass_weight,
+ initial_self_loop_weight=params.initial_self_loop_weight,
+ bypass_weight_decay=params.bypass_weight_decay,
+ self_loop_weight_decay=params.self_loop_weight_decay,
+ )
+
+ # Exclude the OTC token: its probability is derived in compute_loss() as
+ # the average of the non-blank token probabilities, so the model itself
+ # does not need an output unit for it.
+ max_token_id = graph_compiler.get_max_token_id() - 1
+ # +1 for the blank token
+ num_classes = max_token_id + 1
+
+ logging.info("About to create model")
+ model = Conformer(
+ num_features=params.feature_dim,
+ nhead=params.nhead,
+ d_model=params.encoder_dim,
+ num_classes=num_classes,
+ subsampling_factor=params.subsampling_factor,
+ num_encoder_layers=params.num_encoder_layers,
+ num_decoder_layers=params.num_decoder_layers,
+ )
+
+ logging.info(model)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank])
+
+ optimizer = Eve(model.parameters(), lr=params.initial_lr)
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ diagnostic = diagnostics.attach_diagnostics(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ train_cuts = librispeech.train_clean_100_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ return 1.0 <= c.duration <= 20.0
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+ valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+ if params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
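+ # GradScaler(enabled=False) is a no-op, so the same code path also covers
+ # fp32 training.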
+ scaler = GradScaler(enabled=params.use_fp16)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
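+ # Re-seed per epoch so that shuffling and any on-the-fly augmentation
+ # differ across epochs while remaining reproducible.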
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: OtcPhoneTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ # warmup = 0.0 exercises the model in its warm-up configuration for
+ # this sanity check.
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ graph_compiler=graph_compiler,
+ is_training=True,
+ warmup=0.0,
+ )
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ except RuntimeError as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.otc_token = f"{args.otc_token}"
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
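+# Limit PyTorch threading; dataloader workers and multiple DDP processes can
+# otherwise oversubscribe the CPU.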
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/WSASR/local/download_lm.py b/egs/librispeech/WSASR/local/download_lm.py
new file mode 100755
index 000000000..5a36ff2a9
--- /dev/null
+++ b/egs/librispeech/WSASR/local/download_lm.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file downloads the following LibriSpeech LM files:
+
+ - 3-gram.pruned.1e-7.arpa.gz
+ - 4-gram.arpa.gz
+ - librispeech-vocab.txt
+ - librispeech-lexicon.txt
+ - librispeech-lm-norm.txt.gz
+
+from http://www.openslr.org/resources/11
+and save them in the user provided directory.
+
+Files are not re-downloaded if they already exist.
+
+Usage:
+ ./local/download_lm.py --out-dir ./download/lm
+"""
+
+import argparse
+import gzip
+import logging
+import os
+import shutil
+from pathlib import Path
+
+from tqdm.auto import tqdm
+
+
+# This function is copied from lhotse
+def tqdm_urlretrieve_hook(t):
+ """Wraps tqdm instance.
+ Don't forget to close() or __exit__()
+ the tqdm instance once you're done with it (easiest using `with` syntax).
+ Example
+ -------
+ >>> from urllib.request import urlretrieve
+ >>> with tqdm(...) as t:
+ ... reporthook = tqdm_urlretrieve_hook(t)
+ ... urlretrieve(..., reporthook=reporthook)
+
+ Source: https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
+ """
+ last_b = [0]
+
+ def update_to(b=1, bsize=1, tsize=None):
+ """
+ b : int, optional
+ Number of blocks transferred so far [default: 1].
+ bsize : int, optional
+ Size of each block (in tqdm units) [default: 1].
+ tsize : int, optional
+ Total size (in tqdm units). If [default: None] or -1,
+ remains unchanged.
+ """
+ if tsize not in (None, -1):
+ t.total = tsize
+ displayed = t.update((b - last_b[0]) * bsize)
+ last_b[0] = b
+ return displayed
+
+ return update_to
+
+
+# This function is copied from lhotse
+def urlretrieve_progress(url, filename=None, data=None, desc=None):
+ """
+ Works exactly like urllib.request.urlretrieve, but attaches a tqdm hook to
+ display a progress bar of the download.
+ Use "desc" argument to display a user-readable string that informs what is
+ being downloaded.
+ """
+ from urllib.request import urlretrieve
+
+ with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=desc) as t:
+ reporthook = tqdm_urlretrieve_hook(t)
+ return urlretrieve(url=url, filename=filename, reporthook=reporthook, data=data)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--out-dir", type=str, help="Output directory.")
+
+ args = parser.parse_args()
+ return args
+
+
+def main(out_dir: str):
+ url = "http://www.openslr.org/resources/11"
+ out_dir = Path(out_dir)
+
+ files_to_download = (
+ "3-gram.pruned.1e-7.arpa.gz",
+ "4-gram.arpa.gz",
+ "librispeech-vocab.txt",
+ "librispeech-lexicon.txt",
+ "librispeech-lm-norm.txt.gz",
+ )
+
+ for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):
+ filename = out_dir / f
+ if not filename.is_file():
+ urlretrieve_progress(
+ f"{url}/{f}",
+ filename=filename,
+ desc=f"Downloading {filename}",
+ )
+ else:
+ logging.info(f"{filename} already exists - skipping")
+
+ if ".gz" in str(filename):
+ unzipped = Path(os.path.splitext(filename)[0])
+ if not unzipped.is_file():
+ with gzip.open(filename, "rb") as f_in:
+ with open(unzipped, "wb") as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ else:
+ logging.info(f"{unzipped} already exist - skipping")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ logging.info(f"out_dir: {args.out_dir}")
+
+ main(out_dir=args.out_dir)
diff --git a/egs/librispeech/WSASR/local/prepare_otc_lang.py b/egs/librispeech/WSASR/local/prepare_otc_lang.py
new file mode 100755
index 000000000..01865b865
--- /dev/null
+++ b/egs/librispeech/WSASR/local/prepare_otc_lang.py
@@ -0,0 +1,469 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2024 Johns Hopkins University (author: Dongji Gao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
+consisting of words and tokens (i.e., phones) and does the following:
+
+1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
+
+2. Generate tokens.txt, the token table mapping a token to a unique integer.
+
+3. Generate words.txt, the word table mapping a word to a unique integer.
+
+4. Generate L.pt, in k2 format. It can be loaded by
+
+ d = torch.load("L.pt")
+ lexicon = k2.Fsa.from_dict(d)
+
+5. Generate L_disambig.pt, in k2 format.
+"""
+import argparse
+import logging
+import math
+import re
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import k2
+import torch
+
+from icefall.lexicon import write_lexicon
+from icefall.utils import str2bool
+
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ It should contain a file lexicon.txt.
+ Files generated by this script are saved into this directory.
+ """,
+ )
+
+ parser.add_argument(
+ "--otc-token",
+ type=str,
+ default="",
+ help="The OTC token in lexicon",
+ )
+
+ parser.add_argument(
+ "--debug",
+ type=str2bool,
+ default=False,
+ help="""True for debugging, which will generate
+ a visualization of the lexicon FST.
+
+ Caution: If your lexicon contains hundreds of thousands
+ of lines, please set it to False!
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def read_lexicon(
+ filename: str,
+) -> List[Tuple[str, List[str]]]:
+ """Read a lexicon from `filename`.
+
+ Each line in the lexicon contains "word p1 p2 p3 ...".
+ That is, the first field is a word and the remaining
+ fields are tokens. Fields are separated by space(s).
+
+ Args:
+ filename:
+ Path to the lexicon.txt
+
+ Returns:
+ A list of tuples, e.g., [('w', ['p1', 'p2']), ('w1', ['p3', 'p4'])]
+ """
+ ans = []
+
+ with open(filename, "r", encoding="utf-8") as f:
+ whitespace = re.compile("[ \t]+")
+ for line in f:
+ a = whitespace.split(line.strip(" \t\r\n"))
+ if len(a) == 0:
+ continue
+
+ if len(a) < 2:
+ logging.info(f"Found bad line {line} in lexicon file {filename}")
+ logging.info("Every line is expected to contain at least 2 fields")
+ continue
+ word = a[0]
+ if word == "