diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile
index cf0523401..1b6d0026f 100644
--- a/.github/scripts/docker/Dockerfile
+++ b/.github/scripts/docker/Dockerfile
@@ -55,9 +55,9 @@ RUN pip install --no-cache-dir \
"numpy<2.0" \
onnxoptimizer \
onnxsim \
- onnx \
+ onnx==1.17.0 \
onnxmltools \
- onnxruntime \
+ onnxruntime==1.17.1 \
piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html \
pypinyin==0.50.0 \
pytest \
diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
index 638e19498..7f36e278d 100755
--- a/.github/scripts/docker/generate_build_matrix.py
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -63,23 +63,24 @@ def get_torchaudio_version(torch_version):
def get_matrix(min_torch_version, specified_torch_version, specified_python_version):
- k2_version = "1.24.4.dev20241029"
- kaldifeat_version = "1.25.5.dev20241029"
- version = "20241218"
+ k2_version = "1.24.4.dev20250630"
+ kaldifeat_version = "1.25.5.dev20250630"
+ version = "20250630"
# torchaudio 2.5.0 does not support python 3.13
- python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
+ python_version = ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
torch_version = []
torch_version += ["1.13.0", "1.13.1"]
torch_version += ["2.0.0", "2.0.1"]
- # torch_version += ["2.1.0", "2.1.1", "2.1.2"]
- # torch_version += ["2.2.0", "2.2.1", "2.2.2"]
+ torch_version += ["2.1.0", "2.1.1", "2.1.2"]
+ torch_version += ["2.2.0", "2.2.1", "2.2.2"]
# Test only torch >= 2.3.0
torch_version += ["2.3.0", "2.3.1"]
torch_version += ["2.4.0"]
torch_version += ["2.4.1"]
torch_version += ["2.5.0"]
torch_version += ["2.5.1"]
+ torch_version += ["2.6.0", "2.7.0", "2.7.1"]
if specified_torch_version:
torch_version = [specified_torch_version]
@@ -109,12 +110,8 @@ def get_matrix(min_torch_version, specified_torch_version, specified_python_vers
# torch>=2.5 requires python 3.10
continue
- if t == "2.5.1":
- k2_version_2 = "1.24.4.dev20241122"
- kaldifeat_version_2 = "1.25.5.dev20241126"
- else:
- k2_version_2 = k2_version
- kaldifeat_version_2 = kaldifeat_version
+ k2_version_2 = k2_version
+ kaldifeat_version_2 = kaldifeat_version
matrix.append(
{
diff --git a/.github/scripts/generate-piper-phonemize-page.py b/.github/scripts/generate-piper-phonemize-page.py
index 3784d5fa5..e268acf03 100755
--- a/.github/scripts/generate-piper-phonemize-page.py
+++ b/.github/scripts/generate-piper-phonemize-page.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
-def main():
+def get_v1_2_0_files():
prefix = (
"https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/"
)
@@ -19,9 +19,70 @@ def main():
"piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl",
"piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
]
+ ans = [prefix + f for f in files]
+ ans.sort()
+ return ans
+
+
+def get_v1_3_0_files():
+ prefix = (
+ "https://github.com/csukuangfj/piper-phonemize/releases/download/2025.06.23/"
+ )
+ files = [
+ "piper_phonemize-1.3.0-cp310-cp310-macosx_10_9_universal2.whl",
+ "piper_phonemize-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl",
+ "piper_phonemize-1.3.0-cp310-cp310-macosx_11_0_arm64.whl",
+ "piper_phonemize-1.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl",
+ "piper_phonemize-1.3.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl",
+ "piper_phonemize-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.3.0-cp310-cp310-win_amd64.whl",
+ "piper_phonemize-1.3.0-cp311-cp311-macosx_10_9_universal2.whl",
+ "piper_phonemize-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl",
+ "piper_phonemize-1.3.0-cp311-cp311-macosx_11_0_arm64.whl",
+ "piper_phonemize-1.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl",
+ "piper_phonemize-1.3.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl",
+ "piper_phonemize-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.3.0-cp311-cp311-win_amd64.whl",
+ "piper_phonemize-1.3.0-cp312-cp312-macosx_10_13_universal2.whl",
+ "piper_phonemize-1.3.0-cp312-cp312-macosx_10_13_x86_64.whl",
+ "piper_phonemize-1.3.0-cp312-cp312-macosx_11_0_arm64.whl",
+ "piper_phonemize-1.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl",
+ "piper_phonemize-1.3.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl",
+ "piper_phonemize-1.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.3.0-cp312-cp312-win_amd64.whl",
+ "piper_phonemize-1.3.0-cp313-cp313-macosx_10_13_universal2.whl",
+ "piper_phonemize-1.3.0-cp313-cp313-macosx_10_13_x86_64.whl",
+ "piper_phonemize-1.3.0-cp313-cp313-macosx_11_0_arm64.whl",
+ "piper_phonemize-1.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl",
+ "piper_phonemize-1.3.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl",
+ "piper_phonemize-1.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.3.0-cp313-cp313-win_amd64.whl",
+ "piper_phonemize-1.3.0-cp38-cp38-macosx_10_9_universal2.whl",
+ "piper_phonemize-1.3.0-cp38-cp38-macosx_10_9_x86_64.whl",
+ "piper_phonemize-1.3.0-cp38-cp38-macosx_11_0_arm64.whl",
+ "piper_phonemize-1.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl",
+ "piper_phonemize-1.3.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl",
+ "piper_phonemize-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.3.0-cp38-cp38-win_amd64.whl",
+ "piper_phonemize-1.3.0-cp39-cp39-macosx_10_9_universal2.whl",
+ "piper_phonemize-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl",
+ "piper_phonemize-1.3.0-cp39-cp39-macosx_11_0_arm64.whl",
+ "piper_phonemize-1.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl",
+ "piper_phonemize-1.3.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl",
+ "piper_phonemize-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
+ "piper_phonemize-1.3.0-cp39-cp39-win_amd64.whl",
+ ]
+ ans = [prefix + f for f in files]
+ ans.sort()
+ return ans
+
+
+def main():
+ files = get_v1_3_0_files() + get_v1_2_0_files()
+
with open("piper_phonemize.html", "w") as f:
- for file in files:
- url = prefix + file
+ for url in files:
+ file = url.split("/")[-1]
             f.write(f'<a href="{url}">{file}</a><br/>\n')
diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh
deleted file mode 100755
index e254419ff..000000000
--- a/.github/scripts/multi-zh-hans.sh
+++ /dev/null
@@ -1,200 +0,0 @@
-#!/usr/bin/env bash
-
-set -ex
-
-git config --global user.name "k2-fsa"
-git config --global user.email "csukuangfj@gmail.com"
-git config --global lfs.allowincompletepush true
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-log "pwd: $PWD"
-
-cd egs/multi_zh-hans/ASR
-
-repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-pushd $repo
-cd exp
-git lfs pull --include pretrained.pt
-ln -s pretrained.pt epoch-99.pt
-cd ../data/lang_bpe_2000
-ls -lh
-git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model
-git lfs pull --include "*.model"
-ls -lh
-popd
-
-log "--------------------------------------------"
-log "Export non-streaming ONNX transducer models "
-log "--------------------------------------------"
-./zipformer/export-onnx.py \
- --tokens $repo/data/lang_bpe_2000/tokens.txt \
- --use-averaged-model 0 \
- --epoch 99 \
- --avg 1 \
- --exp-dir $repo/exp \
- --causal False
-
-ls -lh $repo/exp
-
-./zipformer/onnx_pretrained.py \
- --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
- --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
- --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
- --tokens $repo/data/lang_bpe_2000/tokens.txt \
- $repo/test_wavs/DEV_T0000000000.wav \
- $repo/test_wavs/DEV_T0000000001.wav \
- $repo/test_wavs/DEV_T0000000002.wav \
- $repo/test_wavs/TEST_MEETING_T0000000113.wav \
- $repo/test_wavs/TEST_MEETING_T0000000219.wav \
- $repo/test_wavs/TEST_MEETING_T0000000351.wav
-
-rm -rf $repo
-
-repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-pushd $repo
-cd exp/
-git lfs pull --include pretrained.pt
-rm -fv epoch-20.pt
-rm -fv *.onnx
-ln -s pretrained.pt epoch-20.pt
-cd ../data/lang_bpe_2000
-ls -lh
-git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model
-git lfs pull --include "*.model"
-ls -lh
-popd
-
-log "----------------------------------------"
-log "Export streaming ONNX CTC models "
-log "----------------------------------------"
-./zipformer/export-onnx-streaming-ctc.py \
- --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_2000/tokens.txt \
- --causal 1 \
- --avg 1 \
- --epoch 20 \
- --use-averaged-model 0 \
- --chunk-size 16 \
- --left-context-frames 128 \
- --use-ctc 1
-
-ls -lh $repo/exp/
-
-log "------------------------------------------------------------"
-log "Test exported streaming ONNX CTC models (greedy search) "
-log "------------------------------------------------------------"
-
-test_wavs=(
-DEV_T0000000000.wav
-DEV_T0000000001.wav
-DEV_T0000000002.wav
-TEST_MEETING_T0000000113.wav
-TEST_MEETING_T0000000219.wav
-TEST_MEETING_T0000000351.wav
-)
-
-for w in ${test_wavs[@]}; do
- ./zipformer/onnx_pretrained-streaming-ctc.py \
- --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
- --tokens $repo/data/lang_bpe_2000/tokens.txt \
- $repo/test_wavs/$w
-done
-
-log "Upload onnx CTC models to huggingface"
-url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
-GIT_LFS_SKIP_SMUDGE=1 git clone $url
-dst=$(basename $url)
-cp -v $repo/exp/ctc*.onnx $dst
-cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
-cp -v $repo/data/lang_bpe_2000/bpe.model $dst
-mkdir -p $dst/test_wavs
-cp -v $repo/test_wavs/*.wav $dst/test_wavs
-cd $dst
-git lfs track "*.onnx" "bpe.model"
-ls -lh
-file bpe.model
-git status
-git add .
-git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
-
-log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
-rm -rf .git
-rm -fv .gitattributes
-cd ..
-tar cjfv $dst.tar.bz2 $dst
-ls -lh *.tar.bz2
-mv -v $dst.tar.bz2 ../../../
-
-log "----------------------------------------"
-log "Export streaming ONNX transducer models "
-log "----------------------------------------"
-
-./zipformer/export-onnx-streaming.py \
- --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_2000/tokens.txt \
- --causal 1 \
- --avg 1 \
- --epoch 20 \
- --use-averaged-model 0 \
- --chunk-size 16 \
- --left-context-frames 128 \
- --use-ctc 0
-
-ls -lh $repo/exp
-
-log "------------------------------------------------------------"
-log "Test exported streaming ONNX transducer models (Python code)"
-log "------------------------------------------------------------"
-
-log "test fp32"
-./zipformer/onnx_pretrained-streaming.py \
- --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.onnx \
- --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \
- --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.onnx \
- --tokens $repo/data/lang_bpe_2000/tokens.txt \
- $repo/test_wavs/DEV_T0000000000.wav
-
-log "test int8"
-./zipformer/onnx_pretrained-streaming.py \
- --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
- --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \
- --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
- --tokens $repo/data/lang_bpe_2000/tokens.txt \
- $repo/test_wavs/DEV_T0000000000.wav
-
-log "Upload onnx transducer models to huggingface"
-
-url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12
-GIT_LFS_SKIP_SMUDGE=1 git clone $url
-dst=$(basename $url)
-cp -v $repo/exp/encoder*.onnx $dst
-cp -v $repo/exp/decoder*.onnx $dst
-cp -v $repo/exp/joiner*.onnx $dst
-cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
-cp -v $repo/data/lang_bpe_2000/bpe.model $dst
-mkdir -p $dst/test_wavs
-cp -v $repo/test_wavs/*.wav $dst/test_wavs
-cd $dst
-git lfs track "*.onnx" bpe.model
-git add .
-git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
-
-log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
-rm -rf .git
-rm -fv .gitattributes
-cd ..
-tar cjfv $dst.tar.bz2 $dst
-ls -lh *.tar.bz2
-mv -v $dst.tar.bz2 ../../../
diff --git a/.github/scripts/multi_zh-hans/ASR/run.sh b/.github/scripts/multi_zh-hans/ASR/run.sh
new file mode 100755
index 000000000..345b64cf0
--- /dev/null
+++ b/.github/scripts/multi_zh-hans/ASR/run.sh
@@ -0,0 +1,756 @@
+#!/usr/bin/env bash
+
+set -ex
+
+git config --global user.name "k2-fsa"
+git config --global user.email "csukuangfj@gmail.com"
+git config --global lfs.allowincompletepush true
+
+python3 -m pip install onnxmltools==1.13.0 onnx==1.17.0 onnxruntime==1.17.1 sherpa-onnx
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/multi_zh-hans/ASR
+
+log "pwd: $PWD"
+
+function run_2023_9_2() {
+ repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+ pushd $repo
+ cd exp
+ git lfs pull --include pretrained.pt
+ ln -s pretrained.pt epoch-99.pt
+ cd ../data/lang_bpe_2000
+ ls -lh
+ git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model
+ git lfs pull --include "*.model"
+ ls -lh
+ popd
+
+ log "--------------------------------------------"
+ log "Export non-streaming ONNX transducer models "
+ log "--------------------------------------------"
+ ./zipformer/export-onnx.py \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --causal False \
+ --fp16 1
+
+ ls -lh $repo/exp
+
+ ./zipformer/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/DEV_T0000000000.wav \
+ $repo/test_wavs/DEV_T0000000001.wav \
+ $repo/test_wavs/DEV_T0000000002.wav \
+ $repo/test_wavs/TEST_MEETING_T0000000113.wav \
+ $repo/test_wavs/TEST_MEETING_T0000000219.wav \
+ $repo/test_wavs/TEST_MEETING_T0000000351.wav
+
+ ./zipformer/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.int8.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.int8.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/DEV_T0000000000.wav \
+ $repo/test_wavs/DEV_T0000000001.wav \
+ $repo/test_wavs/DEV_T0000000002.wav \
+ $repo/test_wavs/TEST_MEETING_T0000000113.wav \
+ $repo/test_wavs/TEST_MEETING_T0000000219.wav \
+ $repo/test_wavs/TEST_MEETING_T0000000351.wav
+
+ ./zipformer/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.fp16.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.fp16.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.fp16.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/DEV_T0000000000.wav \
+ $repo/test_wavs/DEV_T0000000001.wav \
+ $repo/test_wavs/DEV_T0000000002.wav \
+ $repo/test_wavs/TEST_MEETING_T0000000113.wav \
+ $repo/test_wavs/TEST_MEETING_T0000000219.wav \
+ $repo/test_wavs/TEST_MEETING_T0000000351.wav
+
+ rm -rf $repo
+}
+
+function run_2023_11_05_streaming() {
+ repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ pushd $repo
+ cd exp/
+ git lfs pull --include pretrained.pt
+ rm -fv epoch-20.pt
+ rm -fv *.onnx
+ ln -s pretrained.pt epoch-20.pt
+ cd ../data/lang_bpe_2000
+ ls -lh
+ git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model
+ git lfs pull --include "*.model"
+ ls -lh
+ popd
+
+ log "----------------------------------------"
+ log "Export streaming ONNX CTC models "
+ log "----------------------------------------"
+ ./zipformer/export-onnx-streaming-ctc.py \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ --causal 1 \
+ --avg 1 \
+ --epoch 20 \
+ --use-averaged-model 0 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --use-ctc 1 \
+ --fp16 1
+
+ ls -lh $repo/exp/
+
+ log "------------------------------------------------------------"
+ log "Test exported streaming ONNX CTC models (greedy search) "
+ log "------------------------------------------------------------"
+
+ test_wavs=(
+ DEV_T0000000000.wav
+ DEV_T0000000001.wav
+ DEV_T0000000002.wav
+ TEST_MEETING_T0000000113.wav
+ TEST_MEETING_T0000000219.wav
+ TEST_MEETING_T0000000351.wav
+ )
+
+ for w in ${test_wavs[@]}; do
+ log "----fp32----"
+ ./zipformer/onnx_pretrained-streaming-ctc.py \
+ --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/$w
+
+ log "----int8----"
+
+ ./zipformer/onnx_pretrained-streaming-ctc.py \
+ --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/$w
+
+ log "----fp16----"
+
+ ./zipformer/onnx_pretrained-streaming-ctc.py \
+ --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.fp16.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/$w
+ done
+
+ log "Upload onnx CTC models to huggingface"
+ name=(
+ sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
+ sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-int8-2023-12-13
+ sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-fp16-2023-12-13
+ )
+ for n in ${name[@]}; do
+ url=https://huggingface.co/k2-fsa/$n
+ GIT_LFS_SKIP_SMUDGE=1 git clone $url
+ dst=$(basename $url)
+ if [[ $n == sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]]; then
+ cp -v $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.onnx $dst
+ elif [[ $n == sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-int8-2023-12-13 ]]; then
+ cp -v $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx $dst
+ elif [[ $n == sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-fp16-2023-12-13 ]]; then
+ cp -v $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.fp16.onnx $dst
+ fi
+
+ cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
+ cp -v $repo/data/lang_bpe_2000/bpe.model $dst
+ mkdir -p $dst/test_wavs
+ cp -v $repo/test_wavs/*.wav $dst/test_wavs
+ cd $dst
+ git lfs track "*.onnx" "bpe.model" "*.wav"
+ ls -lh
+ file bpe.model
+ git status
+ git add .
+ git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
+
+ log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
+ rm -rf .git
+ rm -fv .gitattributes
+ cd ..
+ tar cjfv $dst.tar.bz2 $dst
+ ls -lh *.tar.bz2
+ mv -v $dst.tar.bz2 ../../../
+ done
+
+ log "----------------------------------------"
+ log "Export streaming ONNX transducer models "
+ log "----------------------------------------"
+
+ ./zipformer/export-onnx-streaming.py \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ --causal 1 \
+ --avg 1 \
+ --epoch 20 \
+ --use-averaged-model 0 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --use-ctc 0 \
+ --fp16 1
+
+ ls -lh $repo/exp
+
+ log "------------------------------------------------------------"
+ log "Test exported streaming ONNX transducer models (Python code)"
+ log "------------------------------------------------------------"
+
+ log "test fp32"
+ ./zipformer/onnx_pretrained-streaming.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/DEV_T0000000000.wav
+
+ log "test int8"
+ ./zipformer/onnx_pretrained-streaming.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/DEV_T0000000000.wav
+
+ log "test fp16"
+ ./zipformer/onnx_pretrained-streaming.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.fp16.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.fp16.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.fp16.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/DEV_T0000000000.wav
+
+ name=(
+ sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-13
+ sherpa-onnx-streaming-zipformer-multi-zh-hans-int8-2023-12-13
+ sherpa-onnx-streaming-zipformer-multi-zh-hans-fp16-2023-12-13
+ )
+
+ for n in ${name[@]}; do
+ url=https://huggingface.co/csukuangfj/$n
+ GIT_LFS_SKIP_SMUDGE=1 git clone $url
+ dst=$(basename $url)
+ if [[ $n == sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-13 ]]; then
+ cp -v $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.onnx $dst
+ cp -v $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx $dst
+ cp -v $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.onnx $dst
+ elif [[ $n == sherpa-onnx-streaming-zipformer-multi-zh-hans-int8-2023-12-13 ]]; then
+ cp -v $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx $dst
+ cp -v $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx $dst
+ cp -v $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx $dst
+ elif [[ $n == sherpa-onnx-streaming-zipformer-multi-zh-hans-fp16-2023-12-13 ]]; then
+ cp -v $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.fp16.onnx $dst
+ cp -v $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.fp16.onnx $dst
+ cp -v $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.fp16.onnx $dst
+ fi
+
+ cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
+ cp -v $repo/data/lang_bpe_2000/bpe.model $dst
+ mkdir -p $dst/test_wavs
+ cp -v $repo/test_wavs/*.wav $dst/test_wavs
+ cd $dst
+ git lfs track "*.onnx" "bpe.model" "*.wav"
+ ls -lh
+ file bpe.model
+ git status
+ git add .
+ git commit -m "upload model" && git push https://csukuangfj:${HF_TOKEN}@huggingface.co/csukuangfj/$dst main || true
+
+ log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
+ rm -rf .git
+ rm -fv .gitattributes
+ cd ..
+ tar cjfv $dst.tar.bz2 $dst
+ ls -lh *.tar.bz2
+ mv -v $dst.tar.bz2 ../../../
+ done
+}
+
+function run_2023_12_12_streaming() {
+ log "Upload onnx transducer models to huggingface"
+
+ url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12
+ GIT_LFS_SKIP_SMUDGE=1 git clone $url
+ dst=$(basename $url)
+ cp -v $repo/exp/encoder*.onnx $dst
+ cp -v $repo/exp/decoder*.onnx $dst
+ cp -v $repo/exp/joiner*.onnx $dst
+ cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
+ cp -v $repo/data/lang_bpe_2000/bpe.model $dst
+ mkdir -p $dst/test_wavs
+ cp -v $repo/test_wavs/*.wav $dst/test_wavs
+ cd $dst
+ git lfs track "*.onnx" bpe.model "*.wav"
+ git add .
+ git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
+
+ log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
+ rm -rf .git
+ rm -fv .gitattributes
+ cd ..
+ tar cjfv $dst.tar.bz2 $dst
+ ls -lh *.tar.bz2
+ mv -v $dst.tar.bz2 ../../../
+}
+
+function run_yuekai_large() {
+ repo_url=https://csukuangfj:${HF_TOKEN}@huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-large
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+ pushd $repo
+ git lfs pull --include pretrained.pt
+ mv pretrained.pt epoch-99.pt
+ curl -SL -O https://huggingface.co/pingzxy/icefall-asr-multi-zh-hans-zipformer-large-onnx/resolve/main/tokens.txt
+ popd
+
+ log "----------------------------------------"
+ log "Export streaming ONNX CTC models "
+ log "----------------------------------------"
+ ./zipformer/export-onnx-streaming-ctc.py \
+ --exp-dir $repo/ \
+ --tokens $repo/tokens.txt \
+ --causal 1 \
+ --avg 1 \
+ --epoch 99 \
+ --use-averaged-model 0 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --use-ctc 1 \
+ \
+ --num-encoder-layers 2,2,4,5,4,2 \
+ --feedforward-dim 768,1024,1536,2048,1536,768 \
+ --encoder-dim 256,384,512,768,512,256 \
+ --encoder-unmasked-dim 192,192,256,320,256,192 \
+ \
+ --fp16 1 \
+ --use-whisper-features 1
+
+
+ ls -lh $repo/
+ pushd $repo
+
+cat >README.md <<EOF
diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py
deleted file mode 100644
--- a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py
+++ /dev/null
-    ) -> DataLoader:
- """
- Args:
- cuts_train:
- CutSet for training.
- sampler_state_dict:
- The state dict for the training sampler.
- """
- transforms = []
- if self.args.enable_musan:
- logging.info("Enable MUSAN")
- logging.info("About to get Musan cuts")
- cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
- transforms.append(
- CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
- )
- else:
- logging.info("Disable MUSAN")
-
- if self.args.concatenate_cuts:
- logging.info(
- f"Using cut concatenation with duration factor "
- f"{self.args.duration_factor} and gap {self.args.gap}."
- )
- # Cut concatenation should be the first transform in the list,
- # so that if we e.g. mix noise in, it will fill the gaps between
- # different utterances.
- transforms = [
- CutConcatenate(
- duration_factor=self.args.duration_factor, gap=self.args.gap
- )
- ] + transforms
-
- input_transforms = []
- if self.args.enable_spec_aug:
- logging.info("Enable SpecAugment")
- logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
- # Set the value of num_frame_masks according to Lhotse's version.
- # In different Lhotse's versions, the default of num_frame_masks is
- # different.
- num_frame_masks = 10
- num_frame_masks_parameter = inspect.signature(
- SpecAugment.__init__
- ).parameters["num_frame_masks"]
- if num_frame_masks_parameter.default == 1:
- num_frame_masks = 2
- logging.info(f"Num frame mask: {num_frame_masks}")
- input_transforms.append(
- SpecAugment(
- time_warp_factor=self.args.spec_aug_time_warp_factor,
- num_frame_masks=num_frame_masks,
- features_mask_size=27,
- num_feature_masks=2,
- frames_mask_size=100,
- )
- )
- else:
- logging.info("Disable SpecAugment")
-
- logging.info("About to create train dataset")
- train = K2SpeechRecognitionDataset(
- input_strategy=eval(self.args.input_strategy)(),
- cut_transforms=transforms,
- input_transforms=input_transforms,
- return_cuts=self.args.return_cuts,
- )
-
- if self.args.on_the_fly_feats:
- # NOTE: the PerturbSpeed transform should be added only if we
- # remove it from data prep stage.
- # Add on-the-fly speed perturbation; since originally it would
- # have increased epoch size by 3, we will apply prob 2/3 and use
- # 3x more epochs.
- # Speed perturbation probably should come first before
- # concatenation, but in principle the transforms order doesn't have
- # to be strict (e.g. could be randomized)
- # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
- # Drop feats to be on the safe side.
- train = K2SpeechRecognitionDataset(
- cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
- input_transforms=input_transforms,
- return_cuts=self.args.return_cuts,
- )
-
- if self.args.bucketing_sampler:
- logging.info("Using DynamicBucketingSampler.")
- train_sampler = DynamicBucketingSampler(
- cuts_train,
- max_duration=self.args.max_duration,
- shuffle=self.args.shuffle,
- num_buckets=self.args.num_buckets,
- drop_last=self.args.drop_last,
- buffer_size=self.args.num_buckets * 2000,
- shuffle_buffer_size=self.args.num_buckets * 5000,
- )
- else:
- logging.info("Using SimpleCutSampler.")
- train_sampler = SimpleCutSampler(
- cuts_train,
- max_duration=self.args.max_duration,
- shuffle=self.args.shuffle,
- )
- logging.info("About to create train dataloader")
-
- if sampler_state_dict is not None:
- logging.info("Loading sampler state dict")
- train_sampler.load_state_dict(sampler_state_dict)
-
- # 'seed' is derived from the current random state, which will have
- # previously been set in the main process.
- seed = torch.randint(0, 100000, ()).item()
- worker_init_fn = _SeedWorkers(seed)
-
- train_dl = DataLoader(
- train,
- sampler=train_sampler,
- batch_size=None,
- num_workers=self.args.num_workers,
- persistent_workers=False,
- worker_init_fn=worker_init_fn,
- )
-
- return train_dl
-
- def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
- transforms = []
- if self.args.concatenate_cuts:
- transforms = [
- CutConcatenate(
- duration_factor=self.args.duration_factor, gap=self.args.gap
- )
- ] + transforms
-
- logging.info("About to create dev dataset")
- if self.args.on_the_fly_feats:
- validate = K2SpeechRecognitionDataset(
- cut_transforms=transforms,
- input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
- return_cuts=self.args.return_cuts,
- )
- else:
- validate = K2SpeechRecognitionDataset(
- cut_transforms=transforms,
- return_cuts=self.args.return_cuts,
- )
- valid_sampler = DynamicBucketingSampler(
- cuts_valid,
- max_duration=self.args.max_duration,
- num_buckets=self.args.num_buckets,
- buffer_size=self.args.num_buckets * 2000,
- shuffle=False,
- )
- logging.info("About to create dev dataloader")
- valid_dl = DataLoader(
- validate,
- sampler=valid_sampler,
- batch_size=None,
- num_workers=2,
- persistent_workers=False,
- )
-
- return valid_dl
-
- def test_dataloaders(self, cuts: CutSet) -> DataLoader:
- logging.debug("About to create test dataset")
- test = K2SpeechRecognitionDataset(
- input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
- if self.args.on_the_fly_feats
- else eval(self.args.input_strategy)(),
- return_cuts=self.args.return_cuts,
- )
- sampler = DynamicBucketingSampler(
- cuts,
- max_duration=self.args.max_duration,
- shuffle=False,
- )
- logging.debug("About to create test dataloader")
- test_dl = DataLoader(
- test,
- batch_size=None,
- sampler=sampler,
- num_workers=self.args.num_workers,
- )
- return test_dl
-
- @lru_cache()
- def train_cuts(self) -> CutSet:
- logging.info(f"About to get train {self.args.subset} cuts")
- if self.args.subset == "XL":
- filenames = glob.glob(
- f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz"
- )
- pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
- idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
- idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
- sorted_filenames = [f[1] for f in idx_filenames]
- logging.info(
- f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
- )
-
- cuts_train = lhotse.combine(
- lhotse.load_manifest_lazy(p) for p in sorted_filenames
- )
- else:
- path = (
- self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
- )
- cuts_train = CutSet.from_jsonl_lazy(path)
- return cuts_train
-
- @lru_cache()
- def dev_cuts(self) -> CutSet:
- logging.info("About to get dev cuts")
- cuts_valid = load_manifest_lazy(
- self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
- )
- if self.args.small_dev:
- return cuts_valid.subset(first=1000)
- else:
- return cuts_valid
-
- @lru_cache()
- def test_cuts(self) -> CutSet:
- logging.info("About to get test cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
- )
-
- @lru_cache()
- def fsc_train_cuts(self) -> CutSet:
- logging.info("About to get fluent speech commands train cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz"
- )
-
- @lru_cache()
- def fsc_valid_cuts(self) -> CutSet:
- logging.info("About to get fluent speech commands valid cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz"
- )
-
- @lru_cache()
- def fsc_test_small_cuts(self) -> CutSet:
- logging.info("About to get fluent speech commands small test cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz"
- )
-
- @lru_cache()
- def fsc_test_large_cuts(self) -> CutSet:
- logging.info("About to get fluent speech commands large test cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz"
- )
diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py
new file mode 120000
index 000000000..75bc3d45a
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py
@@ -0,0 +1 @@
+../../ASR/zipformer/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py
index 39d8fc6cd..2d88b6e55 100755
--- a/egs/gigaspeech/KWS/zipformer/train.py
+++ b/egs/gigaspeech/KWS/zipformer/train.py
@@ -97,6 +97,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -961,7 +962,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py
index bf50bf5ea..63a38a4cc 100755
--- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -805,7 +811,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py
index 485ea69c9..406749f22 100755
--- a/egs/ksponspeech/ASR/zipformer/train.py
+++ b/egs/ksponspeech/ASR/zipformer/train.py
@@ -92,6 +92,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -942,7 +943,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1333,7 +1334,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 7e0bf5b7b..fc866f83b 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -667,7 +667,9 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+ torch.load(
+ f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+ )
)
assert HLG.requires_grad is False
@@ -707,7 +709,9 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+ d = torch.load(
+ params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+ )
G = k2.Fsa.from_dict(d)
if params.method in [
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index 38b60fcb9..5b3a021ad 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -271,7 +271,7 @@ def main():
use_feat_batchnorm=params.use_feat_batchnorm,
)
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
@@ -351,7 +351,9 @@ def main():
"attention-decoder",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -362,7 +364,9 @@ def main():
"attention-decoder",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = G.to(device)
diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py
index 0b271a51c..349e8f02d 100755
--- a/egs/librispeech/ASR/conformer_ctc2/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc2/decode.py
@@ -774,7 +774,9 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+ torch.load(
+ f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+ )
)
assert HLG.requires_grad is False
@@ -814,7 +816,9 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+ d = torch.load(
+ params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+ )
G = k2.Fsa.from_dict(d)
if params.method in [
diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py
index c4a13b101..14c132ada 100755
--- a/egs/librispeech/ASR/conformer_ctc2/train.py
+++ b/egs/librispeech/ASR/conformer_ctc2/train.py
@@ -65,7 +65,6 @@ from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -84,9 +83,11 @@ from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
encode_supervisions,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -420,7 +421,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -629,7 +630,7 @@ def train_one_epoch(
scheduler: LRSchedulerType,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -676,7 +677,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -965,7 +966,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py
index e6327bb5e..cf58fd18d 100755
--- a/egs/librispeech/ASR/conformer_ctc3/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc3/decode.py
@@ -868,7 +868,9 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+ torch.load(
+ f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+ )
)
assert HLG.requires_grad is False
@@ -907,7 +909,9 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+ d = torch.load(
+ params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+ )
G = k2.Fsa.from_dict(d)
if params.decoding_method == "whole-lattice-rescoring":
diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
index 19b26361e..f8e3fa43b 100755
--- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
@@ -334,7 +334,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -345,7 +347,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
index a0cdfcf03..e528b2cb8 100755
--- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
@@ -290,7 +290,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
@@ -386,7 +386,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -397,7 +399,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py
index a2f1125ca..64e77f421 100755
--- a/egs/librispeech/ASR/conformer_ctc3/train.py
+++ b/egs/librispeech/ASR/conformer_ctc3/train.py
@@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed
from model import CTCModel
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -95,9 +94,11 @@ from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
encode_supervisions,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -493,7 +494,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -694,7 +695,7 @@ def train_one_epoch(
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -743,7 +744,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1004,7 +1005,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py
index 74f6e73fa..01fcf0685 100755
--- a/egs/librispeech/ASR/conformer_mmi/decode.py
+++ b/egs/librispeech/ASR/conformer_mmi/decode.py
@@ -574,7 +574,9 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+ torch.load(
+ f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False
+ )
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
@@ -609,7 +611,9 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
+ d = torch.load(
+ params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False
+ )
G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
index ca21bd6bf..fc33f9512 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -93,7 +92,14 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ create_grad_scaler,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -560,7 +566,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -727,7 +733,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -772,7 +778,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1002,7 +1008,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
index 23ddb6bec..b00cc6cc6 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py
@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -93,7 +92,14 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ create_grad_scaler,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -560,7 +566,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -727,7 +733,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -772,7 +778,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1001,7 +1007,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index d19d50ae6..ec39d5b36 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -72,11 +72,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
- L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+ L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
if Path(f"data/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
- d = torch.load(f"data/lm/{lm}.pt")
+ d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py
index 709b14070..bd25cfa29 100755
--- a/egs/librispeech/ASR/local/compile_lg.py
+++ b/egs/librispeech/ASR/local/compile_lg.py
@@ -66,11 +66,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
An FSA representing LG.
"""
lexicon = Lexicon(lang_dir)
- L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+ L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
if Path(f"data/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
- d = torch.load(f"data/lm/{lm}.pt")
+ d = torch.load(f"data/lm/{lm}.pt", weights_only=False)
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
index 856c9d945..8c75eb871 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
@@ -750,7 +750,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py
index e7bad7ed8..9f148b348 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py
@@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@@ -156,7 +156,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -192,7 +192,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
index 42c3a5d7f..f29d1d9db 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
@@ -238,7 +238,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
index feb81d500..e23da3b56 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
@@ -66,7 +66,6 @@ from lstm import RNN
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -82,9 +81,11 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -521,7 +522,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -717,7 +718,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -763,7 +764,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1023,7 +1024,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index 1a724830b..cfbbb334c 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -935,7 +935,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py
index 4957d14b1..5aafe10af 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py
@@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@@ -195,7 +195,7 @@ class Transducer(nn.Module):
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -231,7 +231,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
index dcff088e2..888f9931e 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
@@ -241,7 +241,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
index 4fc4fa7f8..1b31b5485 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
@@ -74,7 +74,6 @@ from lstm import RNN
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -90,9 +89,11 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -560,7 +561,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -772,7 +773,7 @@ def train_one_epoch(
giga_train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -848,7 +849,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1176,7 +1177,7 @@ def run(rank, world_size, args):
else:
logging.info("Skip scan_pessimistic_batches_for_oom")
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
index a2b4f9e1a..e25b79e2e 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
@@ -815,7 +815,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
index e39637bd8..619e783b0 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
@@ -239,7 +239,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
index 2c1cef3a3..e169b499f 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
@@ -66,7 +66,6 @@ from lstm import RNN
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -82,9 +81,11 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -551,7 +552,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -747,7 +748,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -793,7 +794,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1067,7 +1068,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py
index ca8c28af1..3b6ce9b89 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/model.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/model.py
@@ -21,7 +21,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@@ -141,7 +141,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -176,7 +176,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
index 5b595c76c..5850555cd 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py
@@ -10,9 +10,11 @@ from typing import Optional, Tuple
import torch
from scaling import ScaledLinear
from torch import Tensor, nn
-from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd
+from torch.cuda.amp import custom_bwd, custom_fwd
from torch_scheduled_sampling import sample_combined
+from icefall.utils import create_grad_scaler, torch_autocast
+
# The main exports of this file are the module KnowledgeBaseLookup and the
# function create_knowledge_base.
@@ -330,14 +332,14 @@ def _test_knowledge_base_lookup_autocast():
optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04)
m = m.to(device)
- scaler = GradScaler(enabled=True)
+ scaler = create_grad_scaler(enabled=True)
start = timeit.default_timer()
for epoch in range(150):
for n, (x, y) in enumerate(train_pairs):
y_out = m(x)
- with torch.cuda.amp.autocast(enabled=True):
+ with torch_autocast(enabled=True):
loss = ((y_out - y) ** 2).mean() * 100.0
if n % 10 == 0 and epoch % 10 == 0:
print(f"Epoch {epoch}, batch {n}, loss {loss.item()}")
diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py
index 931341cc4..0611fd8cb 100755
--- a/egs/librispeech/ASR/pruned2_knowledge/train.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/train.py
@@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -76,7 +75,14 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ create_grad_scaler,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -453,7 +459,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -608,7 +614,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@@ -650,7 +656,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -868,7 +874,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
index 2b872f1d5..2af8f3f4c 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py
@@ -55,7 +55,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from noam import Noam
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -68,7 +67,14 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ create_grad_scaler,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
def add_model_arguments(parser: argparse.ArgumentParser):
@@ -496,7 +502,7 @@ def save_checkpoint(
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, and training stats to file.
@@ -650,7 +656,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -693,7 +699,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -939,7 +945,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 3c4500087..6d1da7440 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -741,7 +741,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 66c84b2a9..f1d16749c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -1347,7 +1347,10 @@ def modified_beam_search(
(
context_score,
new_context_state,
- ) = context_graph.forward_one_step(hyp.context_state, new_token)
+ _,
+ ) = context_graph.forward_one_step(
+ hyp.context_state, new_token, strict_mode=False
+ )
new_log_prob = topk_log_probs[k] + context_score
@@ -2853,7 +2856,10 @@ def modified_beam_search_LODR(
(
context_score,
new_context_state,
- ) = context_graph.forward_one_step(hyp.context_state, new_token)
+ _,
+ ) = context_graph.forward_one_step(
+ hyp.context_state, new_token, strict_mode=False
+ )
ys.append(new_token)
state_cost = hyp.state_cost.forward_one_step(new_token)
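# Note (not part of the patch, illustrative only): the beam-search hunks above
# assume that ContextGraph.forward_one_step accepts a strict_mode argument and
# returns a third value, which these call sites ignore. A hypothetical helper
# (not icefall API) that tolerates both the old 2-tuple and the new 3-tuple
# return could be written as:
def step_context_graph(context_graph, state, token):
    # Try the newer signature first; fall back for older ContextGraph versions.
    try:
        result = context_graph.forward_one_step(state, token, strict_mode=False)
    except TypeError:
        result = context_graph.forward_one_step(state, token)
    score, new_state = result[0], result[1]  # extra return values are ignored
    return score, new_state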
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index c57514193..5a4a74ebb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -754,7 +754,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
index 272d06c37..6a69332aa 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@@ -157,7 +157,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -193,7 +193,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
index 6923f4d40..e6ddcab25 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -265,7 +265,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 6c19f2cb0..ce6c89614 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -78,7 +78,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -91,9 +90,11 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -523,7 +524,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -716,7 +717,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@@ -759,7 +760,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1000,7 +1001,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 0 else 1.0,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 7c62bfa58..18a3792b0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -921,7 +921,7 @@ def load_ngram_LM(
if pt_file.is_file():
logging.info(f"Loading pre-compiled {pt_file}")
- d = torch.load(pt_file, map_location=device)
+ d = torch.load(pt_file, map_location=device, weights_only=False)
G = k2.Fsa.from_dict(d)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
@@ -1101,7 +1101,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
elif params.decoding_method in [
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
index d45f6dadc..fbc4db921 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
@@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@@ -195,7 +195,7 @@ class Transducer(nn.Module):
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -231,7 +231,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
index 05e6a6fba..19143fb5d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
@@ -274,7 +274,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index fdafa5a87..50670d1b2 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -74,7 +74,6 @@ from librispeech import LibriSpeech
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -87,9 +86,11 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -546,7 +547,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -755,7 +756,7 @@ def train_one_epoch(
giga_train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
- scaler: GradScaler,
+ scaler: "GradScaler",
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@@ -827,7 +828,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1126,7 +1127,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 0 else 1.0,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 5195a4ef6..925c01c7b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -913,7 +913,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 875b03f7f..c35f52309 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -96,9 +95,11 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -548,7 +549,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -744,7 +745,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -789,7 +790,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1047,7 +1048,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
index 7a3e63218..404d7a3d3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
@@ -972,7 +972,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
index a9ce75a7b..9e2669379 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -238,7 +238,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
index 66dc5f991..6f9f92623 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -68,7 +68,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -84,9 +83,11 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -571,7 +572,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -768,7 +769,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -814,7 +815,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1078,7 +1079,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py
index daadb70c9..a5d2457f9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py
@@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@@ -185,7 +185,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -220,7 +220,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
index 8f033cb9a..35ee74f15 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
@@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@@ -96,9 +95,11 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
display_and_save_batch,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -519,7 +520,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -736,7 +737,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -781,7 +782,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1039,7 +1040,7 @@ def run(rank, world_size, args):
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
- scaler = GradScaler(enabled=params.use_fp16)
+ scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
index 3bca7db2c..4f3fbaa81 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py
@@ -348,7 +348,9 @@ class CodebookIndexExtractor:
num_codebooks=self.params.num_codebooks,
codebook_size=256,
)
- quantizer.load_state_dict(torch.load(self.quantizer_file_path))
+ quantizer.load_state_dict(
+ torch.load(self.quantizer_file_path, weights_only=False)
+ )
quantizer.to(self.params.device)
return quantizer
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py
index 27ef0a244..949a497ce 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py
@@ -289,7 +289,7 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
index eb8841cc4..048de7bb9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
@@ -910,7 +910,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
index 7095c3cc8..da1bf17fc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
@@ -813,7 +813,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
index e7546ec45..d3d996b4a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
@@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@@ -85,9 +84,11 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
filter_uneven_sized_batch,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -635,7 +636,7 @@ def load_model_params(
"""
logging.info(f"Loading checkpoint from {ckpt}")
- checkpoint = torch.load(ckpt, map_location="cpu")
+ checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
# if module list is empty, load the whole model from ckpt
if not init_modules:
@@ -678,7 +679,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -857,7 +858,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -903,7 +904,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1219,7 +1220,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index add0e6a18..ed990b689 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -23,7 +23,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@@ -150,7 +150,7 @@ class Transducer(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -185,7 +185,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
index 4bf11ac24..fabda3aaa 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
@@ -247,7 +247,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 30a737061..5a317083c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -28,6 +28,8 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Embedding as ScaledEmbedding
+from icefall.utils import torch_autocast
+
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
@@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32)
x_grad = ans_grad * ans
@@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
def backward(ctx, x_grad: Tensor):
(x_orig,) = ctx.saved_tensors
with torch.enable_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True
@@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module):
):
return _no_op(x)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
eps = 1.0e-20
orig_x = x
x = x.to(torch.float32)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 436ec53b4..f94da9788 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@@ -86,10 +85,12 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
filter_uneven_sized_batch,
setup_logger,
str2bool,
symlink_or_copy,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -581,7 +582,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -763,7 +764,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -809,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1106,7 +1107,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index cbde2a2e4..ee05627ae 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -44,7 +44,7 @@ from scaling import (
from torch import Tensor, nn
from icefall.dist import get_rank
-from icefall.utils import is_jit_tracing, make_pad_mask
+from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast
class Zipformer(EncoderInterface):
@@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads
with torch.no_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32)
attn_weights_entropy = (
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
index 629bec058..3b181bf23 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
@@ -633,7 +633,9 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+ torch.load(
+ f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+ )
)
assert HLG.requires_grad is False
@@ -672,7 +674,9 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+ d = torch.load(
+ params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+ )
G = k2.Fsa.from_dict(d)
if params.decoding_method == "whole-lattice-rescoring":
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
index 7641fa5af..9e16c3fd7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
@@ -786,7 +786,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
index d1b7eec65..f7dd07f8d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
@@ -347,7 +347,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -358,7 +360,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py
index a6e919e2f..f1ab2a3ec 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py
@@ -22,7 +22,7 @@ import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@@ -150,7 +150,7 @@ class Transducer(nn.Module):
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -185,7 +185,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
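Note that both the simple and the pruned RNN-T losses are computed with autocast disabled and with their inputs cast back to `float32`, even when the rest of the forward pass runs under fp16 autocast. The short snippet below (not taken from the recipe) illustrates the dtype behaviour this pattern relies on.

```python
# Minimal illustration of the autocast pattern used above (assumes a CUDA
# device is available): matmuls inside an autocast region produce float16,
# and nesting autocast(enabled=False) plus .float() restores float32 math
# for the numerically sensitive k2 losses.
import torch

if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda")
    w = torch.randn(8, 8, device="cuda")
    with torch.amp.autocast("cuda"):
        y = x @ w
        print(y.dtype)  # torch.float16
        with torch.amp.autocast("cuda", enabled=False):
            z = y.float() @ w.float()
            print(z.dtype)  # torch.float32
```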
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
index 323ba2642..a13952dfa 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
@@ -247,7 +247,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
index 1e638aa7d..32242c94e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
@@ -286,7 +286,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
@@ -365,7 +365,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -376,7 +378,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
index b35e56abc..a26f11c82 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
@@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
encode_supervisions,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -588,7 +589,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -787,7 +788,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -833,7 +834,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1128,7 +1129,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
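The training scripts stop importing `torch.cuda.amp.GradScaler` directly: the class is deprecated in recent PyTorch in favour of `torch.amp.GradScaler("cuda", ...)`, so the type hints become string annotations (no import required at annotation time) and the scaler is built through a `create_grad_scaler` helper from `icefall.utils`. A hypothetical sketch of such a helper; the real implementation may differ:

```python
# Hypothetical sketch of a GradScaler compatibility helper; the actual
# icefall.utils.create_grad_scaler may be implemented differently.
import torch


def create_grad_scaler(device: str = "cuda", **kwargs):
    """Create a gradient scaler with whichever AMP API this PyTorch has.

    torch.cuda.amp.GradScaler(...) is deprecated in newer PyTorch in favour
    of torch.amp.GradScaler("cuda", ...).
    """
    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler(device, **kwargs)
    return torch.cuda.amp.GradScaler(**kwargs)


# Usage matching the hunk above:
# scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
```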
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
index fa7144f0f..3af3ada2c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
@@ -624,7 +624,9 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+ torch.load(
+ f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+ )
)
assert HLG.requires_grad is False
@@ -663,7 +665,9 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+ d = torch.load(
+ params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+ )
G = k2.Fsa.from_dict(d)
if params.decoding_method == "whole-lattice-rescoring":
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py
index e2f08abc6..233f00236 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py
@@ -808,7 +808,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
index e497787d3..025b146b9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
@@ -786,7 +786,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py
index 80604ef4a..70d9841bf 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py
@@ -347,7 +347,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -358,7 +360,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py
index 0582b289f..bf0faf9f1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py
@@ -22,7 +22,7 @@ import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
-from icefall.utils import add_sos, make_pad_mask
+from icefall.utils import add_sos, make_pad_mask, torch_autocast
class Transducer(nn.Module):
@@ -178,7 +178,7 @@ class Transducer(nn.Module):
am = self.simple_am_proj(encoder_out_fr)
lm = self.simple_lm_proj(decoder_out)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -213,7 +213,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py
index a82f3562b..9ceec5f5a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py
@@ -247,7 +247,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py
index b98756a54..431760f9a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py
@@ -286,7 +286,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
@@ -362,7 +362,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -373,7 +375,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
index c2d877a93..5585d74de 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
@@ -63,7 +63,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@@ -82,9 +81,11 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
encode_supervisions,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -581,7 +582,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -778,7 +779,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -822,7 +823,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1118,7 +1119,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1217,7 +1218,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
index 02029c108..aa2fe8e38 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -936,7 +936,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
index aa2dd17fb..f98851f50 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
@@ -247,7 +247,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
index 8bd00bbef..4d8a2644d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@@ -82,7 +81,14 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ create_grad_scaler,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -597,7 +603,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -764,7 +770,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -810,7 +816,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1124,7 +1130,7 @@ def run(rank, world_size, args):
# params=params,
# )
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1224,7 +1230,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
index c7e45564f..640d72b67 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
@@ -44,7 +44,7 @@ from scaling import (
)
from torch import Tensor, nn
-from icefall.utils import make_pad_mask, subsequent_chunk_mask
+from icefall.utils import make_pad_mask, subsequent_chunk_mask, torch_autocast
def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]:
@@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz = n // num_heads
with torch.no_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_output = attn_output.to(torch.float32)
attn_weights_entropy = (
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py
index 35158ced4..61c1a9663 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py
@@ -768,7 +768,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py
index a4f52ad7f..e95bb3357 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py
@@ -788,7 +788,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
index da5e144c9..4b97575e6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
@@ -70,7 +70,6 @@ from librispeech import LibriSpeech
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@@ -86,7 +85,14 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ create_grad_scaler,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -615,7 +621,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -795,7 +801,7 @@ def train_one_epoch(
giga_train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -866,7 +872,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1218,7 +1224,7 @@ def run(rank, world_size, args):
# params=params,
# )
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1320,7 +1326,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
index e07777c9f..3cad83a0b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
@@ -747,7 +747,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py
index 39a360796..e06594c27 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py
@@ -24,7 +24,7 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class Transducer(nn.Module):
@@ -172,7 +172,7 @@ class Transducer(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -207,7 +207,7 @@ class Transducer(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
index c29b8d8c9..693db2beb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
@@ -247,7 +247,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
index 646f30ca1..ad14ec9dc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
@@ -75,7 +75,6 @@ from librispeech import LibriSpeech
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@@ -91,7 +90,14 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ create_grad_scaler,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -608,7 +614,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -790,7 +796,7 @@ def train_one_epoch(
giga_train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -866,7 +872,7 @@ def train_one_epoch(
libri = is_libri(batch["supervisions"]["cut"][0])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1219,7 +1225,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1321,7 +1327,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 1b52aa8b5..283252a46 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -311,8 +311,7 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
- buffer_size=self.args.num_buckets * 2000,
- shuffle_buffer_size=self.args.num_buckets * 5000,
+ buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
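The sampler hunk drops the separate `shuffle_buffer_size` argument and keeps a single, larger `buffer_size`, presumably because current Lhotse versions of `DynamicBucketingSampler` no longer accept the extra knob. A construction sketch under that assumption; the manifest path and `max_duration` value are placeholders:

```python
# Sketch of the sampler construction after the change above; the manifest
# path and max_duration are placeholders, and the argument set assumes a
# recent Lhotse where shuffle_buffer_size has been folded into buffer_size.
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

cuts_train = CutSet.from_file("data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz")

num_buckets = 30
train_sampler = DynamicBucketingSampler(
    cuts_train,
    max_duration=600.0,
    shuffle=True,
    num_buckets=num_buckets,
    buffer_size=num_buckets * 5000,
    drop_last=True,
)

for cut_batch in train_sampler:
    # Each iteration yields a CutSet whose total duration is ~max_duration.
    break
```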
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 92529e06c..db12ab827 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -398,7 +398,9 @@ def main():
logging.info(f"device: {device}")
- HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
assert HLG.requires_grad is False
@@ -428,7 +430,9 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
+ d = torch.load(
+ params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False
+ )
G = k2.Fsa.from_dict(d).to(device)
if params.method == "whole-lattice-rescoring":
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index b3dfab64a..4ad7cb016 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -167,13 +167,15 @@ def main():
subsampling_factor=params.subsampling_factor,
)
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -181,7 +183,9 @@ def main():
if params.method == "whole-lattice-rescoring":
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = G.to(device)
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
index cda03b56e..ec700626a 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
@@ -589,7 +589,9 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+ torch.load(
+ f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+ )
)
assert HLG.requires_grad is False
@@ -628,7 +630,9 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+ d = torch.load(
+ params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+ )
G = k2.Fsa.from_dict(d)
if params.decoding_method == "whole-lattice-rescoring":
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
index cc4471e2b..1b329e8f3 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
@@ -663,7 +663,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py
index 92dea3aa1..4b234a328 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py
@@ -347,7 +347,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -358,7 +360,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
index 5c6956324..9714aa537 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
@@ -249,7 +249,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
index 7698ada79..a2ea1dd06 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
@@ -286,7 +286,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
@@ -365,7 +365,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -376,7 +378,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py
index 1bfd071de..368bd20fa 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py
@@ -51,7 +51,6 @@ from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
@@ -72,9 +71,11 @@ from icefall.lexicon import UniqLexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
encode_supervisions,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@@ -550,7 +551,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -757,7 +758,7 @@ def train_one_epoch(
phone_lexicon: UniqLexicon,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -809,7 +810,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1092,7 +1093,7 @@ def run(rank, world_size, args):
# params=params,
# )
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1198,7 +1199,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index 4d9bbf4b1..06b1c05b9 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -222,7 +222,7 @@ def main():
logging.info("Creating model")
model = get_transducer_model(params)
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py
index fe9347b95..e17407c5f 100755
--- a/egs/librispeech/ASR/zipformer/ctc_decode.py
+++ b/egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -947,7 +947,9 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+ torch.load(
+ f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+ )
)
assert HLG.requires_grad is False
@@ -987,7 +989,9 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+ d = torch.load(
+ params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+ )
G = k2.Fsa.from_dict(d)
if params.decoding_method in [
diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py
index cbfb3728e..6462d22f8 100755
--- a/egs/librispeech/ASR/zipformer/decode.py
+++ b/egs/librispeech/ASR/zipformer/decode.py
@@ -1013,7 +1013,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py
index 3cda337c0..a4da83949 100755
--- a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py
@@ -1049,7 +1049,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
index 99685f2fe..413b5bb1e 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
@@ -153,6 +153,13 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
+ parser.add_argument(
+ "--fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to export models in fp16",
+ )
+
add_model_arguments(parser)
return parser
@@ -176,6 +183,15 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
onnx.save(model, filename)
+def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
+ import onnxmltools
+ from onnxmltools.utils.float16_converter import convert_float_to_float16
+
+ onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
+ onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
+ onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
+
+
class OnnxModel(nn.Module):
"""A wrapper for encoder_embed, Zipformer, and ctc_output layer"""
@@ -407,7 +423,7 @@ def main():
opset_version = 13
logging.info("Exporting ctc model")
- filename = params.exp_dir / f"model.onnx"
+ filename = params.exp_dir / "model.onnx"
export_ctc_model_onnx(
model,
filename,
@@ -420,7 +436,7 @@ def main():
logging.info("Generate int8 quantization models")
- filename_int8 = params.exp_dir / f"model.int8.onnx"
+ filename_int8 = params.exp_dir / "model.int8.onnx"
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
@@ -428,6 +444,10 @@ def main():
weight_type=QuantType.QInt8,
)
+ if params.fp16:
+ filename_fp16 = params.exp_dir / "model.fp16.onnx"
+ export_onnx_fp16(filename, filename_fp16)
+
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
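The new `--fp16` flag (parsed with `str2bool`, so `--fp16 1` and `--fp16 true` both work) converts the exported fp32 graph to fp16 through `onnxmltools` rather than `onnxconverter_common`. The same conversion can be run standalone; file names are placeholders:

```python
# Standalone fp16 conversion of an exported ONNX model, mirroring the
# export_onnx_fp16 helper added above. File names are placeholders.
import onnxmltools
from onnxmltools.utils.float16_converter import convert_float_to_float16

fp32_model = onnxmltools.utils.load_model("model.onnx")

# keep_io_types=True leaves the graph inputs/outputs in float32, so callers
# can keep feeding and reading float32 tensors unchanged.
fp16_model = convert_float_to_float16(fp32_model, keep_io_types=True)

onnxmltools.utils.save_model(fp16_model, "model.fp16.onnx")
```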
diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
index b1f9d70ff..9a715eefd 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
@@ -150,12 +150,35 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
+ parser.add_argument(
+ "--use-whisper-features",
+ type=str2bool,
+ default=False,
+ help="True to use whisper features. Must match the one used in training",
+ )
+
+ parser.add_argument(
+ "--fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to export models in fp16",
+ )
+
+ parser.add_argument(
+ "--use-external-data",
+ type=str2bool,
+ default=False,
+ help="Set it to True if the exported model file is larger than 2 GB",

+ )
+
add_model_arguments(parser)
return parser
-def add_meta_data(filename: str, meta_data: Dict[str, str]):
+def add_meta_data(
+ filename: str, meta_data: Dict[str, str], use_external_data: bool = False
+):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
@@ -164,13 +187,46 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
meta_data:
Key-value pairs.
"""
+ filename = str(filename)
+
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
- onnx.save(model, filename)
+ if use_external_data:
+ # For model files larger than 2 GB
+ external_filename = Path(filename).stem
+
+ onnx.save(
+ model,
+ filename,
+ save_as_external_data=True,
+ all_tensors_to_one_file=True,
+ location=external_filename + ".weights",
+ )
+ else:
+ onnx.save(model, filename)
+
+
+def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
+ import onnxmltools
+ from onnxmltools.utils.float16_converter import convert_float_to_float16
+
+ onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
+ onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
+ onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
+
+
+def export_onnx_fp16_large_2gb(onnx_fp32_path, onnx_fp16_path):
+ import onnxmltools
+ from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path
+
+ onnx_fp16_model = convert_float_to_float16_model_path(
+ onnx_fp32_path, keep_io_types=True
+ )
+ onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
class OnnxModel(nn.Module):
@@ -285,6 +341,8 @@ def export_streaming_ctc_model_onnx(
encoder_filename: str,
opset_version: int = 11,
dynamic_batch: bool = True,
+ use_whisper_features: bool = False,
+ use_external_data: bool = False,
) -> None:
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
@@ -382,6 +440,10 @@ def export_streaming_ctc_model_onnx(
"value_head_dims": value_head_dims,
"num_heads": num_heads,
}
+
+ if use_whisper_features:
+ meta_data["feature"] = "whisper"
+
logging.info(f"meta_data: {meta_data}")
for i in range(len(init_state[:-2]) // 6):
@@ -428,7 +490,11 @@ def export_streaming_ctc_model_onnx(
else {},
)
- add_meta_data(filename=encoder_filename, meta_data=meta_data)
+ add_meta_data(
+ filename=encoder_filename,
+ meta_data=meta_data,
+ use_external_data=use_external_data,
+ )
@torch.no_grad()
@@ -559,12 +625,19 @@ def main():
opset_version = 13
logging.info("Exporting model")
- model_filename = params.exp_dir / f"ctc-{suffix}.onnx"
+
+ if params.use_external_data:
+ model_filename = f"ctc-{suffix}.onnx"
+ else:
+ model_filename = params.exp_dir / f"ctc-{suffix}.onnx"
+
export_streaming_ctc_model_onnx(
model,
- model_filename,
+ str(model_filename),
opset_version=opset_version,
dynamic_batch=params.dynamic_batch == 1,
+ use_whisper_features=params.use_whisper_features,
+ use_external_data=params.use_external_data,
)
logging.info(f"Exported model to {model_filename}")
@@ -574,13 +647,25 @@ def main():
logging.info("Generate int8 quantization models")
+ if params.use_external_data:
+ model_filename_int8 = f"ctc-{suffix}.int8.onnx"
+ else:
model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx"
- quantize_dynamic(
- model_input=model_filename,
- model_output=model_filename_int8,
- op_types_to_quantize=["MatMul"],
- weight_type=QuantType.QInt8,
- )
+
+ quantize_dynamic(
+ model_input=model_filename,
+ model_output=model_filename_int8,
+ op_types_to_quantize=["MatMul"],
+ weight_type=QuantType.QInt8,
+ )
+
+ if params.fp16:
+ if params.use_external_data:
+ model_filename_fp16 = f"ctc-{suffix}.fp16.onnx"
+ export_onnx_fp16_large_2gb(model_filename, model_filename_fp16)
+ else:
+ model_filename_fp16 = params.exp_dir / f"ctc-{suffix}.fp16.onnx"
+ export_onnx_fp16(model_filename, model_filename_fp16)
if __name__ == "__main__":
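`--use-external-data` exists because an ONNX graph is serialized as a single protobuf, which is capped at 2 GB. When the flag is set, the exporter writes the graph to the current directory, stores the tensor weights in a companion `.weights` file via `onnx.save(..., save_as_external_data=True, ...)`, and converts to fp16 with `convert_float_to_float16_model_path`, which works from the on-disk path instead of loading the whole model into one protobuf. A small sketch of the external-data save; the file name is a placeholder:

```python
# Sketch of re-saving an ONNX model with external weights, mirroring
# add_meta_data when use_external_data=True. The file name is a placeholder.
from pathlib import Path

import onnx

filename = "ctc-epoch-99-avg-1.onnx"  # placeholder
model = onnx.load(filename)

onnx.save(
    model,
    filename,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    # Weights end up in e.g. "ctc-epoch-99-avg-1.weights" next to the graph;
    # onnx.load() and onnxruntime pick the side file up automatically.
    location=Path(filename).stem + ".weights",
)
```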
diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index 43ec5d59b..daeb86f6a 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -176,12 +176,47 @@ def get_parser():
help="Whether to export models in fp16",
)
+ parser.add_argument(
+ "--use-whisper-features",
+ type=str2bool,
+ default=False,
+ help="True to use whisper features. Must match the one used in training",
+ )
+
+ parser.add_argument(
+ "--use-external-data",
+ type=str2bool,
+ default=False,
+ help="Set it to True if the exported model file is larger than 2 GB",
+ )
+
add_model_arguments(parser)
return parser
-def add_meta_data(filename: str, meta_data: Dict[str, str]):
+def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
+ import onnxmltools
+ from onnxmltools.utils.float16_converter import convert_float_to_float16
+
+ onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
+ onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
+ onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
+
+
+def export_onnx_fp16_large_2gb(onnx_fp32_path, onnx_fp16_path):
+ import onnxmltools
+ from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path
+
+ onnx_fp16_model = convert_float_to_float16_model_path(
+ onnx_fp32_path, keep_io_types=True
+ )
+ onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
+
+
+def add_meta_data(
+ filename: str, meta_data: Dict[str, str], use_external_data: bool = False
+):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
@@ -196,7 +231,19 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
meta.key = key
meta.value = value
- onnx.save(model, filename)
+ if use_external_data:
+ # For model files larger than 2 GB
+ external_filename = Path(filename).stem
+
+ onnx.save(
+ model,
+ filename,
+ save_as_external_data=True,
+ all_tensors_to_one_file=True,
+ location=external_filename + ".weights",
+ )
+ else:
+ onnx.save(model, filename)
class OnnxEncoder(nn.Module):
@@ -357,6 +404,8 @@ def export_encoder_model_onnx(
opset_version: int = 11,
feature_dim: int = 80,
dynamic_batch: bool = True,
+ use_whisper_features: bool = False,
+ use_external_data: bool = False,
) -> None:
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
@@ -456,6 +505,9 @@ def export_encoder_model_onnx(
"value_head_dims": value_head_dims,
"num_heads": num_heads,
}
+ if use_whisper_features:
+ meta_data["feature"] = "whisper"
+
logging.info(f"meta_data: {meta_data}")
for i in range(len(init_state[:-2]) // 6):
@@ -502,7 +554,11 @@ def export_encoder_model_onnx(
else {},
)
- add_meta_data(filename=encoder_filename, meta_data=meta_data)
+ add_meta_data(
+ filename=encoder_filename,
+ meta_data=meta_data,
+ use_external_data=use_external_data,
+ )
def export_decoder_model_onnx(
@@ -751,13 +807,18 @@ def main():
opset_version = 13
logging.info("Exporting encoder")
- encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
+ if params.use_external_data:
+ encoder_filename = f"encoder-{suffix}.onnx"
+ else:
+ encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
- encoder_filename,
+ str(encoder_filename),
opset_version=opset_version,
feature_dim=params.feature_dim,
dynamic_batch=params.dynamic_batch == 1,
+ use_whisper_features=params.use_whisper_features,
+ use_external_data=params.use_external_data,
)
logging.info(f"Exported encoder to {encoder_filename}")
@@ -782,24 +843,20 @@ def main():
logging.info(f"Exported joiner to {joiner_filename}")
if params.fp16:
- from onnxconverter_common import float16
-
logging.info("Generate fp16 models")
- encoder = onnx.load(encoder_filename)
- encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
- encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
- onnx.save(encoder_fp16, encoder_filename_fp16)
+ if params.use_external_data:
+ encoder_filename_fp16 = f"encoder-{suffix}.fp16.onnx"
+ export_onnx_fp16_large_2gb(encoder_filename, encoder_filename_fp16)
+ else:
+ encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
+ export_onnx_fp16(encoder_filename, encoder_filename_fp16)
- decoder = onnx.load(decoder_filename)
- decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
- onnx.save(decoder_fp16, decoder_filename_fp16)
+ export_onnx_fp16(decoder_filename, decoder_filename_fp16)
- joiner = onnx.load(joiner_filename)
- joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
- onnx.save(joiner_fp16, joiner_filename_fp16)
+ export_onnx_fp16(joiner_filename, joiner_filename_fp16)
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
@@ -807,7 +864,11 @@ def main():
if params.enable_int8_quantization:
logging.info("Generate int8 quantization models")
- encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
+ if params.use_external_data:
+ encoder_filename_int8 = f"encoder-{suffix}.int8.onnx"
+ else:
+ encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
+
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
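For int8, the exporter relies on onnxruntime's post-training dynamic quantization and, as in the other export scripts, restricts it to MatMul weights. Run standalone it looks like this; file names are placeholders:

```python
# Standalone int8 dynamic quantization of an exported encoder, mirroring the
# quantize_dynamic call above. File names are placeholders.
from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="encoder-epoch-99-avg-1.onnx",
    model_output="encoder-epoch-99-avg-1.int8.onnx",
    op_types_to_quantize=["MatMul"],  # quantize MatMul weights only
    weight_type=QuantType.QInt8,
)
```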
diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py
index a56a7a3e6..03c7d6f82 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx.py
@@ -70,7 +70,6 @@ import onnx
import torch
import torch.nn as nn
from decoder import Decoder
-from onnxconverter_common import float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
@@ -182,6 +181,15 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
onnx.save(model, filename)
+def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
+ import onnxmltools
+ from onnxmltools.utils.float16_converter import convert_float_to_float16
+
+ onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
+ onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
+ onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
+
+
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
@@ -595,20 +603,14 @@ def main():
if params.fp16:
logging.info("Generate fp16 models")
- encoder = onnx.load(encoder_filename)
- encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
- onnx.save(encoder_fp16, encoder_filename_fp16)
+ export_onnx_fp16(encoder_filename, encoder_filename_fp16)
- decoder = onnx.load(decoder_filename)
- decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
- onnx.save(decoder_fp16, decoder_filename_fp16)
+ export_onnx_fp16(decoder_filename, decoder_filename_fp16)
- joiner = onnx.load(joiner_filename)
- joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
- onnx.save(joiner_fp16, joiner_filename_fp16)
+ export_onnx_fp16(joiner_filename, joiner_filename_fp16)
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py
index 2ff631914..94e8b273a 100755
--- a/egs/librispeech/ASR/zipformer/finetune.py
+++ b/egs/librispeech/ASR/zipformer/finetune.py
@@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
@@ -95,11 +94,13 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -140,8 +141,8 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
type=str2bool,
default=False,
help="""
- Whether to adapt. If true, we will mix 5% of the new data
- with 95% of the original data to fine-tune. This is useful
+ Whether to adapt. If true, we will mix 5%% of the new data
+ with 95%% of the original data to fine-tune. This is useful
if you want to maintain the performance on the original domain
""",
)
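The doubled percent signs above are needed because argparse runs help strings through %-formatting (for placeholders such as %(default)s); with a bare "5%" the --help output raises a ValueError. A small illustration, not recipe code:

    import argparse

    # A literal "%" in argparse help must be written as "%%".
    p = argparse.ArgumentParser()
    p.add_argument("--mix", help="mix 5%% new data with 95%% original data")
    print(p.format_help())  # renders as: mix 5% new data with 95% original data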
@@ -765,7 +766,7 @@ def load_model_params(
"""
logging.info(f"Loading checkpoint from {ckpt}")
- checkpoint = torch.load(ckpt, map_location="cpu")
+ checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
# if module list is empty, load the whole model from ckpt
if not init_modules:
@@ -808,7 +809,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -985,7 +986,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader,
valid_dls: torch.utils.data.DataLoader,
valid_sets: List[str],
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -1049,7 +1050,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1134,7 +1135,7 @@ def train_one_epoch(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
- f"lr: {cur_lr:.2e}, "
+ f"lr: {cur_lr: .2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
@@ -1373,7 +1374,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1474,7 +1475,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
index fcd07ae34..d1978df52 100755
--- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
@@ -346,7 +346,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -357,7 +359,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
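The weights_only=False arguments added here (and in the other decoding scripts below) account for PyTorch 2.6 switching the default of torch.load to weights_only=True, which refuses to unpickle arbitrary Python objects; checkpoints such as the ones produced by k2.Fsa.as_dict() can bundle non-tensor objects and therefore need the old behavior. Illustrative call, with an example path:

    import torch

    # Trusted local checkpoint: opt back into full unpickling explicitly.
    hlg = torch.load(
        "data/lang_bpe_500/HLG.pt", map_location="cpu", weights_only=False
    )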
diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index c7dbe1e0a..6ef250819 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -25,7 +25,7 @@ from encoder_interface import EncoderInterface
from lhotse.dataset import SpecAugment
from scaling import ScaledLinear
-from icefall.utils import add_sos, make_pad_mask, time_warp
+from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast
class AsrModel(nn.Module):
@@ -210,10 +210,10 @@ class AsrModel(nn.Module):
)
# Compute consistency regularization loss
- exchanged_targets = ctc_output.detach().chunk(2, dim=0)
- exchanged_targets = torch.cat(
- [exchanged_targets[1], exchanged_targets[0]], dim=0
- ) # exchange: [x1, x2] -> [x2, x1]
+ batch_size = ctc_output.shape[0]
+ assert batch_size % 2 == 0, batch_size
+ # exchange: [x1, x2] -> [x2, x1]
+ exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
cr_loss = nn.functional.kl_div(
input=ctc_output,
target=exchanged_targets,
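The consistency-regularization targets were previously swapped by splitting the batch into two chunks and concatenating them in reverse order; rolling by half the batch size gives the same permutation in one call, and the new assertion makes the even-batch-size requirement explicit. A quick self-contained check of the equivalence:

    import torch

    x = torch.arange(6).unsqueeze(1)          # stand-in for ctc_output, batch of 6
    first, second = x.chunk(2, dim=0)
    old = torch.cat([second, first], dim=0)   # exchange: [x1, x2] -> [x2, x1]
    new = torch.roll(x, x.shape[0] // 2, dims=0)
    assert torch.equal(old, new)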
@@ -285,7 +285,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -320,7 +320,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py
index 9f3571b08..65ea7c7f2 100755
--- a/egs/librispeech/ASR/zipformer/pretrained.py
+++ b/egs/librispeech/ASR/zipformer/pretrained.py
@@ -289,7 +289,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py
index 4341ef61f..90a6ff5b8 100755
--- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py
+++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py
@@ -305,7 +305,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
@@ -389,7 +389,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@@ -400,7 +402,9 @@ def main():
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
- G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+ G = k2.Fsa.from_dict(
+ torch.load(params.G, map_location="cpu", weights_only=False)
+ )
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py
index d345c2931..22aa1b1ca 100644
--- a/egs/librispeech/ASR/zipformer/scaling.py
+++ b/egs/librispeech/ASR/zipformer/scaling.py
@@ -26,6 +26,8 @@ import torch.nn as nn
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
+from icefall.utils import torch_autocast
+
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
max_value = torch.max(x, y)
@@ -160,8 +162,10 @@ class PiecewiseLinear(object):
extra_x_vals.append(extra_x_val)
if len(extra_x_vals) > 0:
x_vals = sorted(set(x_vals + extra_x_vals))
- y_vals1 = [self(x) for x in x_vals]
- y_vals2 = [p(x) for x in x_vals]
+
+ y_vals1 = [self(x) for x in x_vals]
+ y_vals2 = [p(x) for x in x_vals]
+
return (
PiecewiseLinear(*zip(x_vals, y_vals1)),
PiecewiseLinear(*zip(x_vals, y_vals2)),
@@ -306,7 +310,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32)
x_grad = ans_grad * ans
@@ -759,7 +763,7 @@ class BalancerFunction(torch.autograd.Function):
try:
with torch.enable_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
x = x.to(torch.float32)
x = x.detach()
x.requires_grad = True
@@ -1014,7 +1018,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
try:
with torch.enable_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True
@@ -1353,7 +1357,7 @@ class SwooshLFunction(torch.autograd.Function):
coeff = -0.08
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True
@@ -1401,9 +1405,9 @@ class SwooshL(torch.nn.Module):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
if not x.requires_grad:
- return k2.swoosh_l_forward(x)
+ return k2.swoosh_l_forward(x).to(x.dtype)
else:
- return k2.swoosh_l(x)
+ return k2.swoosh_l(x).to(x.dtype)
# return SwooshLFunction.apply(x)
@@ -1430,7 +1434,7 @@ class SwooshRFunction(torch.autograd.Function):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True
@@ -1475,9 +1479,9 @@ class SwooshR(torch.nn.Module):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
if not x.requires_grad:
- return k2.swoosh_r_forward(x)
+ return k2.swoosh_r_forward(x).to(x.dtype)
else:
- return k2.swoosh_r(x)
+ return k2.swoosh_r(x).to(x.dtype)
# return SwooshRFunction.apply(x)
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index f8864d58b..42ae9b9f2 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -79,7 +79,6 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
@@ -98,9 +97,11 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -829,7 +830,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -1034,7 +1035,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
spec_augment: Optional[SpecAugment] = None,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
@@ -1101,9 +1102,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(
- enabled=params.use_autocast, dtype=params.dtype
- ):
+ with torch_autocast(enabled=params.use_autocast, dtype=params.dtype):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1449,7 +1448,7 @@ def run(rank, world_size, args):
spec_augment=spec_augment,
)
- scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1551,9 +1550,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(
- enabled=params.use_autocast, dtype=params.dtype
- ):
+ with torch_autocast(enabled=params.use_autocast, dtype=params.dtype):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 2a0ae0129..e83a89400 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -47,6 +47,8 @@ from scaling import (
)
from torch import Tensor, nn
+from icefall.utils import torch_autocast
+
class Zipformer2(EncoderInterface):
"""
@@ -1873,7 +1875,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)
diff --git a/egs/librispeech/ASR/zipformer_adapter/decode.py b/egs/librispeech/ASR/zipformer_adapter/decode.py
index 91533be8d..e8798aed6 100755
--- a/egs/librispeech/ASR/zipformer_adapter/decode.py
+++ b/egs/librispeech/ASR/zipformer_adapter/decode.py
@@ -1005,7 +1005,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
index bbc582f50..66c401761 100755
--- a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py
@@ -1050,7 +1050,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py
index 3511590da..fcd7272e9 100755
--- a/egs/librispeech/ASR/zipformer_adapter/train.py
+++ b/egs/librispeech/ASR/zipformer_adapter/train.py
@@ -67,7 +67,6 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
@@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -762,7 +763,7 @@ def load_model_params(
"""
logging.info(f"Loading checkpoint from {ckpt}")
- checkpoint = torch.load(ckpt, map_location="cpu")
+ checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
# if module list is empty, load the whole model from ckpt
if not init_modules:
@@ -805,7 +806,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -982,7 +983,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader,
valid_dls: torch.utils.data.DataLoader,
valid_sets: List[str],
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -1052,7 +1053,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1397,7 +1398,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1498,7 +1499,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py
index 8e2dfdd72..8bc163db5 100644
--- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py
+++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py
@@ -50,6 +50,8 @@ from scaling import (
)
from torch import Tensor, nn
+from icefall.utils import torch_autocast
+
class Zipformer2(EncoderInterface):
"""
@@ -1916,7 +1918,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)
diff --git a/egs/librispeech/ASR/zipformer_ctc/decode.py b/egs/librispeech/ASR/zipformer_ctc/decode.py
index 7f605e2c8..b9eed099c 100755
--- a/egs/librispeech/ASR/zipformer_ctc/decode.py
+++ b/egs/librispeech/ASR/zipformer_ctc/decode.py
@@ -679,7 +679,9 @@ def main():
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
- torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+ torch.load(
+ f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False
+ )
)
assert HLG.requires_grad is False
@@ -719,7 +721,9 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
- d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+ d = torch.load(
+ params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False
+ )
G = k2.Fsa.from_dict(d)
if params.method in [
diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py
index 60112a84e..bd3bfa332 100755
--- a/egs/librispeech/ASR/zipformer_ctc/train.py
+++ b/egs/librispeech/ASR/zipformer_ctc/train.py
@@ -46,7 +46,6 @@ from lhotse.utils import fix_random_seed
from model import CTCModel
from optim import Eden, LRScheduler, ScaledAdam
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
@@ -65,7 +64,14 @@ from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ create_grad_scaler,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
@@ -533,7 +539,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -687,7 +693,7 @@ def train_one_epoch(
graph_compiler: BpeCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -726,7 +732,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -987,7 +993,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
diff --git a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
index 4d93a905f..acc814a00 100755
--- a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
+++ b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
@@ -1050,7 +1050,7 @@ def main():
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
- torch.load(lg_filename, map_location=device)
+ torch.load(lg_filename, map_location=device, weights_only=False)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py
index 3f36f229f..c26a2f5cc 100755
--- a/egs/librispeech/ASR/zipformer_lora/finetune.py
+++ b/egs/librispeech/ASR/zipformer_lora/finetune.py
@@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
@@ -96,9 +95,11 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -775,7 +776,7 @@ def load_model_params(
"""
logging.info(f"Loading checkpoint from {ckpt}")
- checkpoint = torch.load(ckpt, map_location="cpu")
+ checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False)
# if module list is empty, load the whole model from ckpt
if not init_modules:
@@ -818,7 +819,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -995,7 +996,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader,
valid_dls: torch.utils.data.DataLoader,
valid_sets: List[str],
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -1065,7 +1066,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1406,7 +1407,7 @@ def run(rank, world_size, args):
# params=params,
# )
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1507,7 +1508,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py
index 8d7aa8027..1347570df 100644
--- a/egs/librispeech/ASR/zipformer_lora/scaling.py
+++ b/egs/librispeech/ASR/zipformer_lora/scaling.py
@@ -27,6 +27,8 @@ import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
+from icefall.utils import torch_autocast
+
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
max_value = torch.max(x, y)
@@ -307,7 +309,7 @@ class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32)
x_grad = ans_grad * ans
@@ -863,7 +865,7 @@ class BalancerFunction(torch.autograd.Function):
try:
with torch.enable_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
x = x.to(torch.float32)
x = x.detach()
x.requires_grad = True
@@ -1118,7 +1120,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
try:
with torch.enable_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True
@@ -1457,7 +1459,7 @@ class SwooshLFunction(torch.autograd.Function):
coeff = -0.08
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True
@@ -1534,7 +1536,7 @@ class SwooshRFunction(torch.autograd.Function):
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True
diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py
index 9ab214e86..2b83d58ef 100755
--- a/egs/librispeech/ASR/zipformer_lora/train.py
+++ b/egs/librispeech/ASR/zipformer_lora/train.py
@@ -76,7 +76,6 @@ from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
@@ -94,9 +93,11 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -707,7 +708,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -883,7 +884,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -947,7 +948,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1252,7 +1253,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1352,7 +1353,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py
index 43865609a..b84b1c32a 100644
--- a/egs/librispeech/ASR/zipformer_lora/zipformer.py
+++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py
@@ -49,6 +49,8 @@ from scaling import (
)
from torch import Tensor, nn
+from icefall.utils import torch_autocast
+
class Zipformer2(EncoderInterface):
"""
@@ -1905,7 +1907,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)
diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py
index 33c0bf199..bd3ce21f5 100755
--- a/egs/librispeech/ASR/zipformer_mmi/decode.py
+++ b/egs/librispeech/ASR/zipformer_mmi/decode.py
@@ -569,7 +569,9 @@ def main():
if params.decoding_method == "nbest-rescoring-LG":
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
- LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+ LG = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device, weights_only=False)
+ )
LG = k2.Fsa.from_fsas([LG]).to(device)
LG.lm_scores = LG.scores.clone()
@@ -602,7 +604,11 @@ def main():
torch.save(G.as_dict(), params.lang_dir / f"{order}gram.pt")
else:
logging.info(f"Loading pre-compiled {order}gram.pt")
- d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+ d = torch.load(
+ params.lang_dir / f"{order}gram.pt",
+ map_location=device,
+ weights_only=False,
+ )
G = k2.Fsa.from_dict(d)
G.lm_scores = G.scores.clone()
diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
index 6990c90a0..d5667cafa 100755
--- a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
+++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
@@ -308,7 +308,9 @@ def main():
if method == "nbest-rescoring-LG":
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
- LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+ LG = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device, weights_only=False)
+ )
LG = k2.Fsa.from_fsas([LG]).to(device)
LG.lm_scores = LG.scores.clone()
LM = LG
@@ -317,7 +319,9 @@ def main():
assert order in ("3", "4")
order = int(order)
logging.info(f"Loading pre-compiled {order}gram.pt")
- d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+ d = torch.load(
+ params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False
+ )
G = k2.Fsa.from_dict(d)
G.lm_scores = G.scores.clone()
LM = G
diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
index 1e7afc777..ca860b877 100755
--- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py
+++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
@@ -269,7 +269,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
@@ -331,7 +331,9 @@ def main():
if method == "nbest-rescoring-LG":
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
- LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device))
+ LG = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device, weights_only=False)
+ )
LG = k2.Fsa.from_fsas([LG]).to(device)
LG.lm_scores = LG.scores.clone()
LM = LG
@@ -340,7 +342,9 @@ def main():
assert order in ("3", "4")
order = int(order)
logging.info(f"Loading pre-compiled {order}gram.pt")
- d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device)
+ d = torch.load(
+ params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False
+ )
G = k2.Fsa.from_dict(d)
G.lm_scores = G.scores.clone()
LM = G
diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py
index c1785a328..e0ca0a6a5 100755
--- a/egs/librispeech/ASR/zipformer_mmi/train.py
+++ b/egs/librispeech/ASR/zipformer_mmi/train.py
@@ -64,7 +64,6 @@ from lhotse.utils import fix_random_seed
from model import CTCModel
from optim import Eden, ScaledAdam
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@@ -87,9 +86,11 @@ from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
MetricsTracker,
+ create_grad_scaler,
encode_supervisions,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -514,7 +515,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -696,7 +697,7 @@ def train_one_epoch(
mmi_graph_compiler: MmiTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
+ scaler: "GradScaler",
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@@ -744,7 +745,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1037,7 +1038,7 @@ def run(rank, world_size, args):
params=params,
)
- scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1138,7 +1139,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py
index 17daa3c9d..0080513f3 100644
--- a/egs/librispeech/SSL/hubert/finetune.py
+++ b/egs/librispeech/SSL/hubert/finetune.py
@@ -86,6 +86,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -816,7 +817,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py
index 2723cc770..1ff2b03c0 100644
--- a/egs/librispeech/SSL/hubert/finetune_ce.py
+++ b/egs/librispeech/SSL/hubert/finetune_ce.py
@@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
+ torch_autocast,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
@@ -816,7 +817,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py
index 46a968b69..2c2077376 100644
--- a/egs/librispeech/SSL/hubert/model.py
+++ b/egs/librispeech/SSL/hubert/model.py
@@ -24,7 +24,7 @@ import torch
import torch.nn as nn
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class AsrModel(nn.Module):
@@ -221,7 +221,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -256,7 +256,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py
index f183d90fd..240cd2c0d 100644
--- a/egs/librispeech/SSL/hubert/pretrain.py
+++ b/egs/librispeech/SSL/hubert/pretrain.py
@@ -80,6 +80,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -644,7 +645,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py
index 94948695d..12f95c16f 100644
--- a/egs/librispeech/SSL/hubert/pretrain_ce.py
+++ b/egs/librispeech/SSL/hubert/pretrain_ce.py
@@ -80,6 +80,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -644,7 +645,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py
index c907b41c5..5bebf60f0 100644
--- a/egs/librispeech/SSL/zipformer/finetune.py
+++ b/egs/librispeech/SSL/zipformer/finetune.py
@@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
+ torch_autocast,
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
@@ -1115,7 +1116,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1504,7 +1505,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/SSL/zipformer/model.py b/egs/librispeech/SSL/zipformer/model.py
index 46a968b69..2c2077376 100644
--- a/egs/librispeech/SSL/zipformer/model.py
+++ b/egs/librispeech/SSL/zipformer/model.py
@@ -24,7 +24,7 @@ import torch
import torch.nn as nn
from scaling import ScaledLinear
-from icefall.utils import add_sos
+from icefall.utils import add_sos, torch_autocast
class AsrModel(nn.Module):
@@ -221,7 +221,7 @@ class AsrModel(nn.Module):
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
@@ -256,7 +256,7 @@ class AsrModel(nn.Module):
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py
index 937fb382e..d772f56d0 100644
--- a/egs/librispeech/SSL/zipformer/pretrain.py
+++ b/egs/librispeech/SSL/zipformer/pretrain.py
@@ -78,6 +78,7 @@ from icefall.utils import (
get_parameter_groups_with_lrs,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -944,7 +945,7 @@ def train_one_epoch(
batch_size = batch["kmeans"].shape[0]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1334,7 +1335,7 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py
index e9eff3357..5071a91a8 100644
--- a/egs/librispeech/SSL/zipformer/zipformer.py
+++ b/egs/librispeech/SSL/zipformer/zipformer.py
@@ -22,6 +22,7 @@ import math
import random
import warnings
from typing import List, Optional, Tuple, Union
+from icefall.utils import torch_autocast
import torch
from encoder_interface import EncoderInterface
@@ -1849,7 +1850,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad():
- with torch.cuda.amp.autocast(enabled=False):
+ with torch_autocast(enabled=False):
attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = (
-((attn_weights + 1.0e-20).log() * attn_weights)
diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py
index 82c68803f..19cce1708 100755
--- a/egs/librispeech/WSASR/conformer_ctc2/train.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/train.py
@@ -84,6 +84,7 @@ from icefall.utils import (
get_texts,
setup_logger,
str2bool,
+ torch_autocast,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -757,7 +758,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1076,7 +1077,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py
index b276d0587..a32183bf7 100755
--- a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py
@@ -79,6 +79,7 @@ from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler
from icefall.utils import (
+ torch_autocast,
AttributeDict,
MetricsTracker,
encode_supervisions_otc,
@@ -758,7 +759,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1078,7 +1079,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
index 296f9a4f4..906025b7f 100755
--- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
+++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
@@ -73,6 +73,8 @@ def compute_fbank_ljspeech(num_jobs: int):
f_min=0,
f_max=8000,
)
+ if not torch.cuda.is_available():
+ config.device = "cpu"
prefix = "ljspeech"
suffix = "jsonl.gz"
diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py
index c15db11f7..1b53465c0 100755
--- a/egs/multi_zh-hans/ASR/zipformer/pretrained.py
+++ b/egs/multi_zh-hans/ASR/zipformer/pretrained.py
@@ -328,9 +328,14 @@ def main():
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
- text = ""
+ byte_list = []
for i in token_ids:
- text += token_table[i]
+ token = token_table[i]
+ if token.startswith("<0x") and token.endswith(">"):
+ byte_list.append(int(token[3:-1], 16))
+ else:
+ byte_list += list(token.encode("utf-8"))
+ text = bytes(byte_list).decode("utf-8", errors="ignore")
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
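The BPE model used here falls back to byte tokens such as <0xE4> for characters that are not in the vocabulary; the old code concatenated the token strings verbatim, so those byte tokens appeared literally in the transcript. Collecting raw bytes first and decoding the whole sequence as UTF-8 restores the intended characters. A worked example with an illustrative token list:

    # "你" is the three UTF-8 bytes E4 BD A0, emitted as byte-fallback tokens.
    tokens = ["<0xE4>", "<0xBD>", "<0xA0>", "▁hello"]
    byte_list = []
    for token in tokens:
        if token.startswith("<0x") and token.endswith(">"):
            byte_list.append(int(token[3:-1], 16))
        else:
            byte_list += list(token.encode("utf-8"))
    text = bytes(byte_list).decode("utf-8", errors="ignore")
    print(text.replace("▁", " ").strip())  # -> "你 hello"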
diff --git a/egs/speech_llm/ASR_LLM/RESULTS.md b/egs/speech_llm/ASR_LLM/RESULTS.md
index 830c70397..42dce80c5 100644
--- a/egs/speech_llm/ASR_LLM/RESULTS.md
+++ b/egs/speech_llm/ASR_LLM/RESULTS.md
@@ -42,8 +42,8 @@ huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishel
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
+# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
# First, we only train the projector and freeze other modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
@@ -55,9 +55,10 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
- --use-lora False --unfreeze-llm False
+ --use-lora False \
+ --unfreeze-llm False
-# Then we jointly train the projector and LLM LoRA modules.
+# Then, we jointly train the projector and LLM LoRA modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \
@@ -67,7 +68,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
- --use-lora True --unfreeze-llm True
+ --use-lora True \
+ --unfreeze-llm True \
--pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
```
@@ -77,11 +79,11 @@ mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
+# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
+huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
@@ -94,5 +96,6 @@ python3 ./whisper_llm_zh/decode.py \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
- --use-lora True --dataset aishell
+ --use-lora True \
+ --dataset aishell
```
diff --git a/egs/speech_llm/ASR_LLM/prepare.sh b/egs/speech_llm/ASR_LLM/prepare.sh
old mode 100644
new mode 100755
index 6f5ed5448..8ca3c1c36
--- a/egs/speech_llm/ASR_LLM/prepare.sh
+++ b/egs/speech_llm/ASR_LLM/prepare.sh
@@ -7,6 +7,9 @@ set -eou pipefail
stage=0
stop_stage=0
+
+. shared/parse_options.sh || exit 1
+
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
@@ -23,7 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# pip install huggingface_hub['cli']
# for aishell 1
- huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse
+ huggingface-cli download --repo-type dataset --local-dir data yuekai/aishell_whisper_fbank_lhotse
fi
@@ -31,9 +34,9 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
# for multi-hans-zh
- huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
- huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
- huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
+ huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
+ huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
+ huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -41,6 +44,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# for speechio test sets
mkdir data_speechio
- huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio
+ huggingface-cli download --repo-type model --local-dir data_speechio yuekai/icefall_asr_speechio
mv data_speechio/fbank/* data/fbank
fi
diff --git a/egs/speech_llm/ASR_LLM/shared b/egs/speech_llm/ASR_LLM/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/speech_llm/ASR_LLM/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
index 882ce4fbf..3036b471e 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py
@@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
+from icefall.checkpoint import load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
@@ -357,43 +357,6 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""
-
- def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
- """
- Text normalization similar to M2MeT challenge baseline.
- See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
- """
- if normalize == "none":
- return text
- elif normalize == "m2met":
- import re
-
- text = text.replace(" ", "")
- text = text.replace("", "")
- text = text.replace("<%>", "")
- text = text.replace("<->", "")
- text = text.replace("<$>", "")
- text = text.replace("<#>", "")
- text = text.replace("<_>", "")
- text = text.replace("", "")
- text = text.replace("`", "")
- text = text.replace("&", "")
- text = text.replace(",", "")
- if re.search("[a-zA-Z]", text):
- text = text.upper()
- text = text.replace("A", "A")
- text = text.replace("a", "A")
- text = text.replace("b", "B")
- text = text.replace("c", "C")
- text = text.replace("k", "K")
- text = text.replace("t", "T")
- text = text.replace(",", "")
- text = text.replace("丶", "")
- text = text.replace("。", "")
- text = text.replace("、", "")
- text = text.replace("?", "")
- return text
-
results = []
num_cuts = 0
@@ -406,6 +369,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
+ texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
@@ -418,12 +382,8 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
- for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
- ref_text = normalize_text_alimeeting(ref_text)
- ref_words = ref_text.split()
- print(f"ref: {ref_text}")
- print(f"hyp: {''.join(hyp_words)}")
- this_batch.append((cut_id, ref_words, hyp_words))
+ for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
+ this_batch.append((cut_id, ref_text, hyp_text))
results[lm_scale].extend(this_batch)
@@ -439,40 +399,38 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
- results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
-
- enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
- params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
- if enable_log:
- logging.info(f"The transcripts are stored in {recog_path}")
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
+ logging.info(f"The transcripts are stored in {recog_path}")
- # The following prints out WERs, per-word error statistics and aligned
+ # The following prints out CERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
- params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
- # we compute CER for aishell dataset.
- results_char = []
- for res in results:
- results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+ f,
+ f"{test_set_name}-{key}",
+ results,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
- if enable_log:
- logging.info("Wrote detailed error stats to {}".format(errs_filename))
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
- errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
@@ -495,9 +453,13 @@ def main():
params = get_params()
params.update(vars(args))
+
+ params.res_dir = params.exp_dir / f"{params.method}"
+
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(
- f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+ params.res_dir
+ / f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}"
)
logging.info("Decoding started")
@@ -574,23 +536,20 @@ def main():
if params.avg > 1:
start = params.epoch - params.avg + 1
assert start >= 1, start
- checkpoint = torch.load(
- f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
- )
- assert "model" not in checkpoint
# deepspeed converted checkpoint only contains model state_dict
filenames = [
- f"{params.exp_dir}/epoch-{epoch}.pt"
+ f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin"
for epoch in range(start, params.epoch + 1)
]
avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False)
- filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
- torch.save(avg_checkpoint, filename)
+ # filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+ # torch.save(avg_checkpoint, filename)
else:
checkpoint = torch.load(
- f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+ f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
+ map_location="cpu",
)
model.load_state_dict(checkpoint, strict=False)
@@ -643,8 +602,7 @@ def main():
logging.info("Done!")
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
main()
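
For reference, a hedged sketch of how the new DeepSpeed checkpoint layout assumed by decode.py can be averaged by hand: every epoch is expected to live in `<exp_dir>/epoch-N/pytorch_model.bin` after the ZeRO-to-fp32 conversion, and the exp dir, epoch, and avg values below are placeholders.

```python
# Hedged sketch only: average DeepSpeed-converted checkpoints saved as
# <exp_dir>/epoch-N/pytorch_model.bin (directory name and epoch/avg are placeholders).
from pathlib import Path

import torch

from icefall.checkpoint import average_checkpoints

exp_dir = Path("./whisper_llm_zh/exp")  # assumed experiment directory
epoch, avg = 10, 5
filenames = [
    f"{exp_dir}/epoch-{e}/pytorch_model.bin" for e in range(epoch - avg + 1, epoch + 1)
]
state_dict = average_checkpoints(filenames, device=torch.device("cpu"))
# model.load_state_dict(state_dict, strict=False)  # strict=False: frozen params were excluded
```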
diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
index 5f224c984..7947a60a5 100755
--- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
+# 2025 Yifan Yang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -42,47 +43,32 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
"""
import argparse
-import copy
import logging
import os
-import random
import warnings
from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple
import deepspeed
-import k2
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
import transformers
import whisper
from asr_datamodule import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
-from label_smoothing import LabelSmoothingLoss
-from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
-from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from icefall import diagnostics
from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info
-from icefall.utils import (
- AttributeDict,
- MetricsTracker,
- filter_uneven_sized_batch,
- setup_logger,
- str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
DEFAULT_SPEECH_TOKEN = "<speech>"
@@ -286,13 +272,6 @@ def compute_loss(
Returns:
Return a tuple of two elements. The first element is the loss tensor.
"""
- # For the uneven-sized batch, the total duration after padding would possibly
- # cause OOM. Hence, for each batch, which is sorted descendingly by length,
- # we simply drop the last few shortest samples, so that the retained total frames
- # (after padding) would not exceed `allowed_max_frames`:
- # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
- # where `max_frames = max_duration * 1000 // frame_shift_ms`.
- # We set allowed_excess_duration_ratio=0.1.
def preprocess(
messages,
@@ -347,46 +326,6 @@ def compute_loss(
return input_ids, attention_mask, target_ids
-    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-        """
-        Text normalization similar to M2MeT challenge baseline.
-        See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-        """
-        if normalize == "none":
-            return text
-        elif normalize == "m2met":
-            import re
-
-            text = text.replace(" ", "")
-            text = text.replace("<sil>", "")
-            text = text.replace("<%>", "")
-            text = text.replace("<->", "")
-            text = text.replace("<$>", "")
-            text = text.replace("<#>", "")
-            text = text.replace("<_>", "")
-            text = text.replace("<space>", "")
-            text = text.replace("`", "")
-            text = text.replace("&", "")
-            text = text.replace(",", "")
-            if re.search("[a-zA-Z]", text):
-                text = text.upper()
-            text = text.replace("Ａ", "A")
-            text = text.replace("ａ", "A")
-            text = text.replace("ｂ", "B")
-            text = text.replace("ｃ", "C")
-            text = text.replace("ｋ", "K")
-            text = text.replace("ｔ", "T")
-            text = text.replace("，", "")
-            text = text.replace("丶", "")
-            text = text.replace("。", "")
-            text = text.replace("、", "")
-            text = text.replace("？", "")
-            return text
-
- max_frames = params.max_duration * 1000 // params.frame_shift_ms
- allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
- batch = filter_uneven_sized_batch(batch, allowed_max_frames)
-
device = next(model.parameters()).device
feature = batch["inputs"]
@@ -397,11 +336,10 @@ def compute_loss(
batch_idx_train = params.batch_idx_train
supervisions = batch["supervisions"]
texts = batch["supervisions"]["text"]
- # remove spaces in texts
- texts = [normalize_text_alimeeting(text) for text in texts]
messages = []
for i, text in enumerate(texts):
+ text = text.replace(" ", "")
message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": text},
@@ -516,14 +454,17 @@ def train_one_epoch(
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
- model.encoder_projector.train()
+ model.train()
+ model.encoder.eval()
+ if not params.unfreeze_llm:
+ model.llm.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@@ -533,6 +474,9 @@ def train_one_epoch(
world_size=world_size,
)
model.train()
+ model.encoder.eval()
+ if not params.unfreeze_llm:
+ model.llm.eval()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -648,7 +592,6 @@ def run(rank, world_size, args):
speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters():
param.requires_grad = False
- speech_encoder.eval()
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn:
@@ -671,7 +614,6 @@ def run(rank, world_size, args):
if not params.unfreeze_llm:
for name, param in llm.named_parameters():
param.requires_grad = False
- llm.eval()
else:
if params.use_lora:
lora_config = LoraConfig(
@@ -728,7 +670,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
model.to(device)
- assert params.deepspeed and world_size > 1
+ assert params.deepspeed
logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters()
@@ -764,7 +706,7 @@ def run(rank, world_size, args):
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
sampler_state_dict["max_duration"] = params.max_duration
- # TODO: load sampler state dict
+
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
@@ -806,15 +748,15 @@ def run(rank, world_size, args):
model.save_checkpoint(
save_dir=params.exp_dir,
- tag=f"epoch-{params.cur_epoch}",
+ tag=f"zero-epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
- f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
- tag=f"epoch-{params.cur_epoch}",
+ f"{params.exp_dir}/epoch-{params.cur_epoch}",
+ tag=f"zero-epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
@@ -824,7 +766,7 @@ def run(rank, world_size, args):
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
)
- os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
+ os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
logging.info("Done!")
@@ -865,6 +807,7 @@ def main():
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
+ warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args)
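
The repeated `model.train()` / `model.encoder.eval()` / `model.llm.eval()` pattern above keeps the frozen submodules out of train mode after every validation pass. A small illustrative helper (not part of the patch; attribute names assumed from the SPEECH_LLM wrapper) could look like this:

```python
# Illustrative only: keep frozen submodules in eval mode whenever the
# SPEECH_LLM wrapper is switched back to training.
import torch.nn as nn


def set_train_mode(model: nn.Module, unfreeze_llm: bool) -> None:
    model.train()         # dropout etc. for trainable parts (projector, LoRA adapters)
    model.encoder.eval()  # Whisper encoder is frozen; keep it deterministic
    if not unfreeze_llm:
        model.llm.eval()  # a frozen LLM should not see train-mode dropout
```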
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
index 170f37767..2455f3630 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py
@@ -422,7 +422,7 @@ def compute_loss(
texts = batch["supervisions"]["text"]
unk_id = params.unk_id
- y = convert_texts_into_ids(texts, unk_id, sp=sp)
+ y = convert_texts_into_ids(texts, sp=sp)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py
index 6fed32e81..c6fa34e70 100755
--- a/egs/tedlium3/ASR/transducer_stateless/train.py
+++ b/egs/tedlium3/ASR/transducer_stateless/train.py
@@ -397,7 +397,7 @@ def compute_loss(
texts = batch["supervisions"]["text"]
unk_id = params.unk_id
- y = convert_texts_into_ids(texts, unk_id, sp=sp)
+ y = convert_texts_into_ids(texts, sp=sp)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 8b35187b1..2cc06df71 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule:
group.add_argument(
"--num-buckets",
type=int,
- default=30,
+ default=15,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
@@ -292,8 +292,7 @@ class WenetSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
- buffer_size=self.args.num_buckets * 2000,
- shuffle_buffer_size=self.args.num_buckets * 5000,
+ buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
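
For context, a hedged sketch of how these sampler knobs map onto lhotse's DynamicBucketingSampler; the manifest path and max_duration are placeholders, not values from this recipe.

```python
# Hedged sketch with placeholder cuts/durations: the settings touched above.
from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler

cuts_train = load_manifest_lazy("data/fbank/cuts_L.jsonl.gz")  # placeholder path
sampler = DynamicBucketingSampler(
    cuts_train,
    max_duration=200.0,     # seconds of audio per batch (placeholder)
    shuffle=True,
    num_buckets=15,         # new default in this patch
    buffer_size=15 * 5000,  # single buffer knob; shuffle_buffer_size is dropped
    drop_last=True,
)
```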
diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh
index 8472b8531..0af7c1595 100755
--- a/egs/wenetspeech/KWS/run.sh
+++ b/egs/wenetspeech/KWS/run.sh
@@ -108,7 +108,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
- log "Stage 2: Finetune the model"
+ log "Stage 3: Finetune the model"
# The following configuration of lr schedule should work well
# You may also tune the following parameters to adjust learning rate schedule
@@ -143,7 +143,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
- log "Stage 1: Decode the finetuned model."
+ log "Stage 4: Decode the finetuned model."
export CUDA_VISIBLE_DEVICES="0"
for t in small large; do
python ./zipformer/decode.py \
@@ -170,7 +170,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
- log "Stage 2: Export the finetuned model."
+ log "Stage 5: Export the finetuned model."
python ./zipformer/export.py \
--epoch 10 \
diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py
index 340a41231..a628c7e58 100755
--- a/egs/wenetspeech/KWS/zipformer/decode.py
+++ b/egs/wenetspeech/KWS/zipformer/decode.py
@@ -35,7 +35,6 @@ from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py
index d19172b38..b1abfd79e 100755
--- a/egs/wenetspeech/KWS/zipformer/finetune.py
+++ b/egs/wenetspeech/KWS/zipformer/finetune.py
@@ -69,6 +69,7 @@ import argparse
import copy
import logging
import warnings
+from functools import partial
from pathlib import Path
from typing import List, Optional, Tuple, Union
@@ -90,6 +91,7 @@ from train import (
add_training_arguments,
compute_validation_loss,
display_and_save_batch,
+ encode_text,
get_adjusted_batch_count,
get_model,
get_params,
@@ -100,7 +102,6 @@ from train import (
)
from icefall import diagnostics
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@@ -110,11 +111,11 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
-from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
+ num_tokens,
setup_logger,
str2bool,
text_to_pinyin,
@@ -254,7 +255,6 @@ def load_model_params(
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
- graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
@@ -289,7 +289,7 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
- y = graph_compiler.texts_to_ids(texts, sep="/")
+ y = [c.supervisions[0].tokens for c in supervisions["cut"]]
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
@@ -347,7 +347,6 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
- graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
@@ -418,7 +417,6 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
@@ -436,7 +434,7 @@ def train_one_epoch(
optimizer.zero_grad()
except: # noqa
save_bad_model()
- display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ display_and_save_batch(batch, params=params)
raise
if params.print_diagnostics and batch_idx == 5:
@@ -523,7 +521,6 @@ def train_one_epoch(
valid_info = compute_validation_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
@@ -576,14 +573,10 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
- lexicon = Lexicon(params.lang_dir)
- graph_compiler = CharCtcTrainingGraphCompiler(
- lexicon=lexicon,
- device=device,
- )
+ token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")
-    params.blank_id = lexicon.token_table["<blk>"]
-    params.vocab_size = max(lexicon.tokens) + 1
+    params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
if not params.use_transducer:
params.ctc_loss_scale = 1.0
@@ -601,6 +594,9 @@ def run(rank, world_size, args):
if params.continue_finetune:
assert params.start_epoch > 0, params.start_epoch
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
@@ -666,17 +662,10 @@ def run(rank, world_size, args):
else:
train_cuts = wenetspeech.nihaowenwen_train_cuts()
- def encode_text(c: Cut):
- # Text normalize for each sample
- text = c.supervisions[0].text
- text = "/".join(
- text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
- )
- c.supervisions[0].text = text
- return c
+ _encode_text = partial(encode_text, token_table=token_table, params=params)
train_cuts = train_cuts.filter(remove_short_utt)
- train_cuts = train_cuts.map(encode_text)
+ train_cuts = train_cuts.map(_encode_text)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@@ -691,7 +680,7 @@ def run(rank, world_size, args):
valid_cuts = wenetspeech.nihaowenwen_dev_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
- valid_cuts = valid_cuts.map(encode_text)
+ valid_cuts = valid_cuts.map(_encode_text)
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics and params.scan_for_oom_batches:
@@ -699,7 +688,6 @@ def run(rank, world_size, args):
model=model,
train_dl=train_dl,
optimizer=optimizer,
- graph_compiler=graph_compiler,
params=params,
)
@@ -724,7 +712,6 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
- graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
@@ -760,6 +747,8 @@ def main():
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+ args.return_cuts = True
world_size = args.world_size
assert world_size >= 1
diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py
index 40960c2ae..5d9d8de36 100755
--- a/egs/wenetspeech/KWS/zipformer/train.py
+++ b/egs/wenetspeech/KWS/zipformer/train.py
@@ -53,6 +53,7 @@ import argparse
import copy
import logging
import warnings
+from functools import partial
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
@@ -79,7 +80,6 @@ from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
from icefall import diagnostics
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@@ -90,11 +90,11 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
-from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
+ num_tokens,
setup_logger,
str2bool,
text_to_pinyin,
@@ -776,7 +776,6 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
- graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
@@ -811,7 +810,7 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
- y = graph_compiler.texts_to_ids(texts, sep="/")
+ y = [c.supervisions[0].tokens for c in supervisions["cut"]]
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
@@ -859,7 +858,6 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
- graph_compiler: CharCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
@@ -872,7 +870,6 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
batch=batch,
is_training=False,
)
@@ -895,7 +892,6 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
- graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
@@ -971,7 +967,6 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
@@ -988,7 +983,7 @@ def train_one_epoch(
optimizer.zero_grad()
except: # noqa
save_bad_model()
- display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ display_and_save_batch(batch, params=params)
raise
if params.print_diagnostics and batch_idx == 5:
@@ -1077,7 +1072,6 @@ def train_one_epoch(
valid_info = compute_validation_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
@@ -1098,6 +1092,20 @@ def train_one_epoch(
params.best_train_loss = params.train_loss
+def encode_text(c: Cut, token_table: k2.SymbolTable, params: AttributeDict):
+ text = c.supervisions[0].text
+ tokens = text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
+ ids = []
+ for t in tokens:
+ if t in token_table:
+ ids.append(token_table[t])
+ else:
+            logging.warning(f"Text : {text} has OOV token : {t} , encode to <unk>")
+            ids.append(token_table["<unk>"])
+ c.supervisions[0].tokens = ids
+ return c
+
+
def run(rank, world_size, args):
"""
Args:
@@ -1130,14 +1138,10 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
- lexicon = Lexicon(params.lang_dir)
- graph_compiler = CharCtcTrainingGraphCompiler(
- lexicon=lexicon,
- device=device,
- )
+ token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")
-    params.blank_id = lexicon.token_table["<blk>"]
-    params.vocab_size = max(lexicon.tokens) + 1
+    params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
if not params.use_transducer:
params.ctc_loss_scale = 1.0
@@ -1216,17 +1220,10 @@ def run(rank, world_size, args):
return True
- def encode_text(c: Cut):
- # Text normalize for each sample
- text = c.supervisions[0].text
- text = "/".join(
- text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
- )
- c.supervisions[0].text = text
- return c
+ _encode_text = partial(encode_text, token_table=token_table, params=params)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
- train_cuts = train_cuts.map(encode_text)
+ train_cuts = train_cuts.map(_encode_text)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@@ -1240,7 +1237,7 @@ def run(rank, world_size, args):
)
valid_cuts = wenetspeech.valid_cuts()
- valid_cuts = valid_cuts.map(encode_text)
+ valid_cuts = valid_cuts.map(_encode_text)
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics and params.scan_for_oom_batches:
@@ -1248,7 +1245,6 @@ def run(rank, world_size, args):
model=model,
train_dl=train_dl,
optimizer=optimizer,
- graph_compiler=graph_compiler,
params=params,
)
@@ -1273,7 +1269,6 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
- graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
@@ -1307,7 +1302,6 @@ def run(rank, world_size, args):
def display_and_save_batch(
batch: dict,
params: AttributeDict,
- graph_compiler: CharCtcTrainingGraphCompiler,
) -> None:
"""Display the batch statistics and save the batch into disk.
@@ -1317,8 +1311,6 @@ def display_and_save_batch(
for the content in it.
params:
Parameters for training. See :func:`get_params`.
- graph_compiler:
- The compiler to encode texts to ids.
"""
from lhotse.utils import uuid4
@@ -1332,8 +1324,8 @@ def display_and_save_batch(
logging.info(f"features shape: {features.shape}")
texts = supervisions["text"]
- y = graph_compiler.texts_to_ids(texts)
- num_tokens = sum(len(i) for i in y)
+ tokens = [c.supervisions[0].tokens for c in supervisions["cut"]]
+ num_tokens = sum(len(i) for i in tokens)
logging.info(f"num tokens: {num_tokens}")
@@ -1341,7 +1333,6 @@ def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
- graph_compiler: CharCtcTrainingGraphCompiler,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
@@ -1357,7 +1348,6 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
@@ -1372,7 +1362,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
- display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ display_and_save_batch(batch, params=params)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -1385,6 +1375,7 @@ def main():
args = parser.parse_args()
args.lang_dir = Path(args.lang_dir)
args.exp_dir = Path(args.exp_dir)
+ args.return_cuts = True
world_size = args.world_size
assert world_size >= 1
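
To make the new token flow concrete, here is a hedged sketch of applying encode_text lazily to a CutSet; the lang dir, manifest path, and pinyin settings are assumptions for illustration only.

```python
# Hedged sketch: pre-compute pinyin token ids on each cut, as train.py now expects.
from functools import partial

import k2
from lhotse import load_manifest_lazy

from icefall.utils import AttributeDict
from train import encode_text  # helper added in this diff

params = AttributeDict({"pinyin_type": "partial_with_tone", "pinyin_errors": "default"})
token_table = k2.SymbolTable.from_file("data/lang_partial_tone/tokens.txt")  # assumed path

_encode_text = partial(encode_text, token_table=token_table, params=params)
cuts = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz").map(_encode_text)  # assumed path
# compute_loss reads supervisions[0].tokens from the returned cuts, which is why
# args.return_cuts is forced to True in main() above.
```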
diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py
index e0a94bf08..3de7136ec 100755
--- a/egs/yesno/ASR/local/compile_hlg.py
+++ b/egs/yesno/ASR/local/compile_hlg.py
@@ -47,7 +47,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
- L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+ L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
logging.info("Loading G.fst.txt")
with open("data/lm/G.fst.txt") as f:
diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py
index f520607af..479e195fa 100755
--- a/egs/yesno/ASR/tdnn/decode.py
+++ b/egs/yesno/ASR/tdnn/decode.py
@@ -271,7 +271,9 @@ def main():
logging.info(f"device: {device}")
- HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
assert HLG.requires_grad is False
diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py
index 6c643c263..2a0879045 100755
--- a/egs/yesno/ASR/tdnn/jit_pretrained.py
+++ b/egs/yesno/ASR/tdnn/jit_pretrained.py
@@ -131,7 +131,9 @@ def main():
model.to(device)
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
logging.info("Constructing Fbank computer")
diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py
index 968a9e9a8..e6471d2db 100755
--- a/egs/yesno/ASR/tdnn/onnx_pretrained.py
+++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py
@@ -176,7 +176,9 @@ def main():
model = OnnxModel(params.nn_model)
logging.info(f"Loading HLG from {args.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
logging.info("Constructing Fbank computer")
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index bea520998..d4f3ae39f 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -148,13 +148,15 @@ def main():
num_classes=params.num_classes,
)
- checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
- HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+ HLG = k2.Fsa.from_dict(
+ torch.load(params.HLG, map_location="cpu", weights_only=False)
+ )
HLG = HLG.to(device)
logging.info("Constructing Fbank computer")
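
These weights_only=False additions anticipate newer PyTorch releases, where torch.load defaults to weights_only=True (from torch 2.6) and refuses the non-tensor objects stored in HLG.pt and in icefall checkpoints. A minimal sketch with placeholder paths:

```python
# Minimal sketch (placeholder paths): k2.Fsa dicts and icefall checkpoints hold
# arbitrary Python objects, so they must be loaded with weights_only=False.
import k2
import torch

HLG = k2.Fsa.from_dict(
    torch.load("data/lang_phone/HLG.pt", map_location="cpu", weights_only=False)
)
checkpoint = torch.load("tdnn/exp/epoch-14.pt", map_location="cpu", weights_only=False)
# model.load_state_dict(checkpoint["model"])
```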
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index d31ce1301..4ab685684 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -27,7 +27,6 @@ import torch
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from torch import Tensor
-from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@@ -43,7 +42,7 @@ def save_checkpoint(
params: Optional[Dict[str, Any]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
@@ -102,7 +101,7 @@ def load_checkpoint(
model_avg: Optional[nn.Module] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
sampler: Optional[CutSampler] = None,
strict: bool = False,
) -> Dict[str, Any]:
@@ -110,7 +109,7 @@ def load_checkpoint(
TODO: document it
"""
logging.info(f"Loading checkpoint from {filename}")
- checkpoint = torch.load(filename, map_location="cpu")
+ checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
if next(iter(checkpoint["model"])).startswith("module."):
logging.info("Loading checkpoint saved by DDP")
@@ -163,7 +162,7 @@ def average_checkpoints(
"""
n = len(filenames)
- avg = torch.load(filenames[0], map_location=device)["model"]
+ avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"]
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@@ -178,7 +177,9 @@ def average_checkpoints(
uniqued_names = list(uniqued.values())
for i in range(1, n):
- state_dict = torch.load(filenames[i], map_location=device)["model"]
+ state_dict = torch.load(filenames[i], map_location=device, weights_only=False)[
+ "model"
+ ]
for k in uniqued_names:
avg[k] += state_dict[k]
@@ -199,7 +200,7 @@ def save_checkpoint_with_global_batch_idx(
params: Optional[Dict[str, Any]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
- scaler: Optional[GradScaler] = None,
+ scaler: Optional["GradScaler"] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
):
@@ -421,8 +422,10 @@ def average_checkpoints_with_averaged_model(
device:
Move checkpoints to this device before averaging.
"""
- state_dict_start = torch.load(filename_start, map_location=device)
- state_dict_end = torch.load(filename_end, map_location=device)
+ state_dict_start = torch.load(
+ filename_start, map_location=device, weights_only=False
+ )
+ state_dict_end = torch.load(filename_end, map_location=device, weights_only=False)
average_period = state_dict_start["average_period"]
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index e5eaba619..d923e8842 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -631,7 +631,10 @@ def attach_diagnostics(
)
module.register_forward_hook(forward_hook)
- module.register_backward_hook(backward_hook)
+ if hasattr(module, "register_full_backward_hook"):
+ module.register_full_backward_hook(backward_hook)
+ else:
+ module.register_backward_hook(backward_hook)
if type(module).__name__ in [
"Sigmoid",
@@ -665,7 +668,10 @@ def attach_diagnostics(
_model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)
module.register_forward_hook(scalar_forward_hook)
- module.register_backward_hook(scalar_backward_hook)
+ if hasattr(module, "register_full_backward_hook"):
+ module.register_full_backward_hook(scalar_backward_hook)
+ else:
+ module.register_backward_hook(scalar_backward_hook)
for name, parameter in model.named_parameters():
diff --git a/icefall/hooks.py b/icefall/hooks.py
index 85583acbe..b543190be 100644
--- a/icefall/hooks.py
+++ b/icefall/hooks.py
@@ -77,7 +77,11 @@ def register_inf_check_hooks(model: nn.Module) -> None:
logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
module.register_forward_hook(forward_hook)
- module.register_backward_hook(backward_hook)
+
+ if hasattr(module, "register_full_backward_hook"):
+ module.register_full_backward_hook(backward_hook)
+ else:
+ module.register_backward_hook(backward_hook)
for name, parameter in model.named_parameters():
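
Both diagnostics.py and hooks.py now prefer register_full_backward_hook and only fall back to the deprecated register_backward_hook on torch builds that lack it. An illustrative wrapper (not part of the patch):

```python
# Illustrative wrapper for the version-tolerant registration used above.
import torch.nn as nn


def attach_backward_hook(module: nn.Module, hook) -> None:
    if hasattr(module, "register_full_backward_hook"):
        # Available since torch 1.8; avoids the deprecation warning of the old API.
        module.register_full_backward_hook(hook)
    else:
        module.register_backward_hook(hook)
```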
diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py
index 0178b80bf..023afb5a5 100755
--- a/icefall/rnn_lm/train.py
+++ b/icefall/rnn_lm/train.py
@@ -53,7 +53,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
def get_parser():
@@ -401,7 +407,7 @@ def compute_validation_loss(
for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
@@ -470,7 +476,7 @@ def train_one_epoch(
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py
index c36abfcdf..acec95e94 100644
--- a/icefall/transformer_lm/train.py
+++ b/icefall/transformer_lm/train.py
@@ -50,7 +50,13 @@ from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ setup_logger,
+ str2bool,
+ torch_autocast,
+)
def get_parser():
@@ -341,7 +347,7 @@ def compute_validation_loss(
for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
@@ -403,7 +409,7 @@ def train_one_epoch(
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
- with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ with torch_autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
diff --git a/icefall/utils.py b/icefall/utils.py
index aab479e56..022f83b3b 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -26,6 +26,7 @@ import pathlib
import random
import re
import subprocess
+import warnings
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
@@ -42,6 +43,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
+from packaging import version
from pypinyin import lazy_pinyin, pinyin
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
from torch.utils.tensorboard import SummaryWriter
@@ -50,6 +52,48 @@ from icefall.checkpoint import average_checkpoints
Pathlike = Union[str, Path]
+TORCH_VERSION = version.parse(torch.__version__)
+
+
+def create_grad_scaler(device="cuda", **kwargs):
+ """
+ Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0.
+ Accepts all kwargs like: enabled, init_scale, growth_factor, etc.
+
+ /icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning:
+ `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use
+ `torch.amp.GradScaler('cuda', args...)` instead.
+ """
+ if TORCH_VERSION >= version.parse("2.3.0"):
+ from torch.amp import GradScaler
+
+ return GradScaler(device=device, **kwargs)
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=FutureWarning)
+ return torch.cuda.amp.GradScaler(**kwargs)
+
+
+@contextmanager
+def torch_autocast(device_type="cuda", **kwargs):
+ """
+ To fix the following warnings:
+ /icefall/egs/librispeech/ASR/zipformer/model.py:323:
+ FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated.
+ Please use `torch.amp.autocast('cuda', args...)` instead.
+ with torch.cuda.amp.autocast(enabled=False):
+ """
+ if TORCH_VERSION >= version.parse("2.3.0"):
+ # Use new unified API
+ with torch.amp.autocast(device_type=device_type, **kwargs):
+ yield
+ else:
+ # Suppress deprecation warning and use old CUDA-specific autocast
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=FutureWarning)
+ with torch.cuda.amp.autocast(**kwargs):
+ yield
+
# Pytorch issue: https://github.com/pytorch/pytorch/issues/47379
# Fixed: https://github.com/pytorch/pytorch/pull/49853
@@ -186,7 +230,7 @@ class AttributeDict(dict):
tmp = {}
for k, v in self.items():
            # PosixPath is not JSON serializable
- if isinstance(v, pathlib.Path) or isinstance(v, torch.device):
+ if isinstance(v, (pathlib.Path, torch.device, torch.dtype)):
v = str(v)
tmp[k] = v
return json.dumps(tmp, indent=indent, sort_keys=True)
@@ -1551,6 +1595,7 @@ def optim_step_and_measure_param_change(
and the L2 norm of the original parameter. It is given by the formula:
.. math::
+
\begin{aligned}
\delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2}
\end{aligned}
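
A short usage sketch of the two new AMP helpers; the toy model, optimizer, and CUDA/fp16 choices below are placeholders for illustration.

```python
# Hedged usage sketch of create_grad_scaler / torch_autocast (toy example only).
import torch
import torch.nn as nn

from icefall.utils import create_grad_scaler, torch_autocast

model = nn.Linear(80, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = create_grad_scaler(enabled=True)

features = torch.randn(8, 80, device="cuda")
with torch_autocast(dtype=torch.float16, enabled=True):
    loss = model(features).sum()

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```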
diff --git a/requirements.txt b/requirements.txt
index d97263142..885bf2fc3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,3 +21,4 @@ flake8==5.0.4
# cantonese word segment support
pycantonese==3.4.0
+packaging