diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index cf0523401..1b6d0026f 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -55,9 +55,9 @@ RUN pip install --no-cache-dir \ "numpy<2.0" \ onnxoptimizer \ onnxsim \ - onnx \ + onnx==1.17.0 \ onnxmltools \ - onnxruntime \ + onnxruntime==1.17.1 \ piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html \ pypinyin==0.50.0 \ pytest \ diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 638e19498..7f36e278d 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -63,23 +63,24 @@ def get_torchaudio_version(torch_version): def get_matrix(min_torch_version, specified_torch_version, specified_python_version): - k2_version = "1.24.4.dev20241029" - kaldifeat_version = "1.25.5.dev20241029" - version = "20241218" + k2_version = "1.24.4.dev20250630" + kaldifeat_version = "1.25.5.dev20250630" + version = "20250630" # torchaudio 2.5.0 does not support python 3.13 - python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] + python_version = ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] torch_version = [] torch_version += ["1.13.0", "1.13.1"] torch_version += ["2.0.0", "2.0.1"] - # torch_version += ["2.1.0", "2.1.1", "2.1.2"] - # torch_version += ["2.2.0", "2.2.1", "2.2.2"] + torch_version += ["2.1.0", "2.1.1", "2.1.2"] + torch_version += ["2.2.0", "2.2.1", "2.2.2"] # Test only torch >= 2.3.0 torch_version += ["2.3.0", "2.3.1"] torch_version += ["2.4.0"] torch_version += ["2.4.1"] torch_version += ["2.5.0"] torch_version += ["2.5.1"] + torch_version += ["2.6.0", "2.7.0", "2.7.1"] if specified_torch_version: torch_version = [specified_torch_version] @@ -109,12 +110,8 @@ def get_matrix(min_torch_version, specified_torch_version, specified_python_vers # torch>=2.5 requires python 3.10 continue - if t == "2.5.1": - k2_version_2 = "1.24.4.dev20241122" - kaldifeat_version_2 = "1.25.5.dev20241126" - else: - k2_version_2 = k2_version - kaldifeat_version_2 = kaldifeat_version + k2_version_2 = k2_version + kaldifeat_version_2 = kaldifeat_version matrix.append( { diff --git a/.github/scripts/generate-piper-phonemize-page.py b/.github/scripts/generate-piper-phonemize-page.py index 3784d5fa5..e268acf03 100755 --- a/.github/scripts/generate-piper-phonemize-page.py +++ b/.github/scripts/generate-piper-phonemize-page.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 -def main(): +def get_v1_2_0_files(): prefix = ( "https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/" ) @@ -19,9 +19,70 @@ def main(): "piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl", "piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", ] + ans = [prefix + f for f in files] + ans.sort() + return ans + + +def get_v1_3_0_files(): + prefix = ( + "https://github.com/csukuangfj/piper-phonemize/releases/download/2025.06.23/" + ) + files = [ + "piper_phonemize-1.3.0-cp310-cp310-macosx_10_9_universal2.whl", + "piper_phonemize-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", + "piper_phonemize-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp310-cp310-win_amd64.whl", + 
"piper_phonemize-1.3.0-cp311-cp311-macosx_10_9_universal2.whl", + "piper_phonemize-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", + "piper_phonemize-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp311-cp311-win_amd64.whl", + "piper_phonemize-1.3.0-cp312-cp312-macosx_10_13_universal2.whl", + "piper_phonemize-1.3.0-cp312-cp312-macosx_10_13_x86_64.whl", + "piper_phonemize-1.3.0-cp312-cp312-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp312-cp312-win_amd64.whl", + "piper_phonemize-1.3.0-cp313-cp313-macosx_10_13_universal2.whl", + "piper_phonemize-1.3.0-cp313-cp313-macosx_10_13_x86_64.whl", + "piper_phonemize-1.3.0-cp313-cp313-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp313-cp313-win_amd64.whl", + "piper_phonemize-1.3.0-cp38-cp38-macosx_10_9_universal2.whl", + "piper_phonemize-1.3.0-cp38-cp38-macosx_10_9_x86_64.whl", + "piper_phonemize-1.3.0-cp38-cp38-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp38-cp38-win_amd64.whl", + "piper_phonemize-1.3.0-cp39-cp39-macosx_10_9_universal2.whl", + "piper_phonemize-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", + "piper_phonemize-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp39-cp39-win_amd64.whl", + ] + ans = [prefix + f for f in files] + ans.sort() + return ans + + +def main(): + files = get_v1_3_0_files() + get_v1_2_0_files() + with open("piper_phonemize.html", "w") as f: - for file in files: - url = prefix + file + for url in files: + file = url.split("/")[-1] f.write(f'{file}
\n') diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh deleted file mode 100755 index e254419ff..000000000 --- a/.github/scripts/multi-zh-hans.sh +++ /dev/null @@ -1,200 +0,0 @@ -#!/usr/bin/env bash - -set -ex - -git config --global user.name "k2-fsa" -git config --global user.email "csukuangfj@gmail.com" -git config --global lfs.allowincompletepush true - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "pwd: $PWD" - -cd egs/multi_zh-hans/ASR - -repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2 -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) -pushd $repo -cd exp -git lfs pull --include pretrained.pt -ln -s pretrained.pt epoch-99.pt -cd ../data/lang_bpe_2000 -ls -lh -git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model -git lfs pull --include "*.model" -ls -lh -popd - -log "--------------------------------------------" -log "Export non-streaming ONNX transducer models " -log "--------------------------------------------" -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_2000/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --causal False - -ls -lh $repo/exp - -./zipformer/onnx_pretrained.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_2000/tokens.txt \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav \ - $repo/test_wavs/TEST_MEETING_T0000000113.wav \ - $repo/test_wavs/TEST_MEETING_T0000000219.wav \ - $repo/test_wavs/TEST_MEETING_T0000000351.wav - -rm -rf $repo - -repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -cd exp/ -git lfs pull --include pretrained.pt -rm -fv epoch-20.pt -rm -fv *.onnx -ln -s pretrained.pt epoch-20.pt -cd ../data/lang_bpe_2000 -ls -lh -git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model -git lfs pull --include "*.model" -ls -lh -popd - -log "----------------------------------------" -log "Export streaming ONNX CTC models " -log "----------------------------------------" -./zipformer/export-onnx-streaming-ctc.py \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_2000/tokens.txt \ - --causal 1 \ - --avg 1 \ - --epoch 20 \ - --use-averaged-model 0 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --use-ctc 1 - -ls -lh $repo/exp/ - -log "------------------------------------------------------------" -log "Test exported streaming ONNX CTC models (greedy search) " -log "------------------------------------------------------------" - -test_wavs=( -DEV_T0000000000.wav -DEV_T0000000001.wav -DEV_T0000000002.wav -TEST_MEETING_T0000000113.wav -TEST_MEETING_T0000000219.wav -TEST_MEETING_T0000000351.wav -) - -for w in ${test_wavs[@]}; do - ./zipformer/onnx_pretrained-streaming-ctc.py \ - --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ - --tokens $repo/data/lang_bpe_2000/tokens.txt \ - $repo/test_wavs/$w -done - -log "Upload onnx CTC models to huggingface" 
-url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 -GIT_LFS_SKIP_SMUDGE=1 git clone $url -dst=$(basename $url) -cp -v $repo/exp/ctc*.onnx $dst -cp -v $repo/data/lang_bpe_2000/tokens.txt $dst -cp -v $repo/data/lang_bpe_2000/bpe.model $dst -mkdir -p $dst/test_wavs -cp -v $repo/test_wavs/*.wav $dst/test_wavs -cd $dst -git lfs track "*.onnx" "bpe.model" -ls -lh -file bpe.model -git status -git add . -git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true - -log "Upload models to https://github.com/k2-fsa/sherpa-onnx" -rm -rf .git -rm -fv .gitattributes -cd .. -tar cjfv $dst.tar.bz2 $dst -ls -lh *.tar.bz2 -mv -v $dst.tar.bz2 ../../../ - -log "----------------------------------------" -log "Export streaming ONNX transducer models " -log "----------------------------------------" - -./zipformer/export-onnx-streaming.py \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_2000/tokens.txt \ - --causal 1 \ - --avg 1 \ - --epoch 20 \ - --use-averaged-model 0 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --use-ctc 0 - -ls -lh $repo/exp - -log "------------------------------------------------------------" -log "Test exported streaming ONNX transducer models (Python code)" -log "------------------------------------------------------------" - -log "test fp32" -./zipformer/onnx_pretrained-streaming.py \ - --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.onnx \ - --tokens $repo/data/lang_bpe_2000/tokens.txt \ - $repo/test_wavs/DEV_T0000000000.wav - -log "test int8" -./zipformer/onnx_pretrained-streaming.py \ - --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ - --tokens $repo/data/lang_bpe_2000/tokens.txt \ - $repo/test_wavs/DEV_T0000000000.wav - -log "Upload onnx transducer models to huggingface" - -url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12 -GIT_LFS_SKIP_SMUDGE=1 git clone $url -dst=$(basename $url) -cp -v $repo/exp/encoder*.onnx $dst -cp -v $repo/exp/decoder*.onnx $dst -cp -v $repo/exp/joiner*.onnx $dst -cp -v $repo/data/lang_bpe_2000/tokens.txt $dst -cp -v $repo/data/lang_bpe_2000/bpe.model $dst -mkdir -p $dst/test_wavs -cp -v $repo/test_wavs/*.wav $dst/test_wavs -cd $dst -git lfs track "*.onnx" bpe.model -git add . -git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true - -log "Upload models to https://github.com/k2-fsa/sherpa-onnx" -rm -rf .git -rm -fv .gitattributes -cd .. 
-tar cjfv $dst.tar.bz2 $dst -ls -lh *.tar.bz2 -mv -v $dst.tar.bz2 ../../../ diff --git a/.github/scripts/multi_zh-hans/ASR/run.sh b/.github/scripts/multi_zh-hans/ASR/run.sh new file mode 100755 index 000000000..345b64cf0 --- /dev/null +++ b/.github/scripts/multi_zh-hans/ASR/run.sh @@ -0,0 +1,756 @@ +#!/usr/bin/env bash + +set -ex + +git config --global user.name "k2-fsa" +git config --global user.email "csukuangfj@gmail.com" +git config --global lfs.allowincompletepush true + +python3 -m pip install onnxmltools==1.13.0 onnx==1.17.0 onnxruntime==1.17.1 sherpa-onnx + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/multi_zh-hans/ASR + +log "pwd: $PWD" + +function run_2023_9_2() { + repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2 + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + pushd $repo + cd exp + git lfs pull --include pretrained.pt + ln -s pretrained.pt epoch-99.pt + cd ../data/lang_bpe_2000 + ls -lh + git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model + git lfs pull --include "*.model" + ls -lh + popd + + log "--------------------------------------------" + log "Export non-streaming ONNX transducer models " + log "--------------------------------------------" + ./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False \ + --fp16 1 + + ls -lh $repo/exp + + ./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav \ + $repo/test_wavs/TEST_MEETING_T0000000113.wav \ + $repo/test_wavs/TEST_MEETING_T0000000219.wav \ + $repo/test_wavs/TEST_MEETING_T0000000351.wav + + ./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.int8.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.int8.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav \ + $repo/test_wavs/TEST_MEETING_T0000000113.wav \ + $repo/test_wavs/TEST_MEETING_T0000000219.wav \ + $repo/test_wavs/TEST_MEETING_T0000000351.wav + + ./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.fp16.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.fp16.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.fp16.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav \ + $repo/test_wavs/TEST_MEETING_T0000000113.wav \ + $repo/test_wavs/TEST_MEETING_T0000000219.wav \ + $repo/test_wavs/TEST_MEETING_T0000000351.wav + + rm -rf $repo +} + +function run_2023_11_05_streaming() { + repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 + log "Downloading pre-trained model from $repo_url" + 
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + pushd $repo + cd exp/ + git lfs pull --include pretrained.pt + rm -fv epoch-20.pt + rm -fv *.onnx + ln -s pretrained.pt epoch-20.pt + cd ../data/lang_bpe_2000 + ls -lh + git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model + git lfs pull --include "*.model" + ls -lh + popd + + log "----------------------------------------" + log "Export streaming ONNX CTC models " + log "----------------------------------------" + ./zipformer/export-onnx-streaming-ctc.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --causal 1 \ + --avg 1 \ + --epoch 20 \ + --use-averaged-model 0 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --use-ctc 1 \ + --fp16 1 + + ls -lh $repo/exp/ + + log "------------------------------------------------------------" + log "Test exported streaming ONNX CTC models (greedy search) " + log "------------------------------------------------------------" + + test_wavs=( + DEV_T0000000000.wav + DEV_T0000000001.wav + DEV_T0000000002.wav + TEST_MEETING_T0000000113.wav + TEST_MEETING_T0000000219.wav + TEST_MEETING_T0000000351.wav + ) + + for w in ${test_wavs[@]}; do + log "----fp32----" + ./zipformer/onnx_pretrained-streaming-ctc.py \ + --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/$w + + log "----int8----" + + ./zipformer/onnx_pretrained-streaming-ctc.py \ + --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/$w + + log "----fp16----" + + ./zipformer/onnx_pretrained-streaming-ctc.py \ + --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.fp16.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/$w + done + + log "Upload onnx CTC models to huggingface" + name=( + sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 + sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-int8-2023-12-13 + sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-fp16-2023-12-13 + ) + for n in ${name[@]}; do + url=https://huggingface.co/k2-fsa/$n + GIT_LFS_SKIP_SMUDGE=1 git clone $url + dst=$(basename $url) + if [[ $n == sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]]; then + cp -v $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.onnx $dst + elif [[ $n == sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-int8-2023-12-13 ]]; then + cp -v $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx $dst + elif [[ $n == sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-fp16-2023-12-13 ]]; then + cp -v $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.fp16.onnx $dst + fi + + cp -v $repo/data/lang_bpe_2000/tokens.txt $dst + cp -v $repo/data/lang_bpe_2000/bpe.model $dst + mkdir -p $dst/test_wavs + cp -v $repo/test_wavs/*.wav $dst/test_wavs + cd $dst + git lfs track "*.onnx" "bpe.model" "*.wav" + ls -lh + file bpe.model + git status + git add . + git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true + + log "Upload models to https://github.com/k2-fsa/sherpa-onnx" + rm -rf .git + rm -fv .gitattributes + cd .. 
+ tar cjfv $dst.tar.bz2 $dst + ls -lh *.tar.bz2 + mv -v $dst.tar.bz2 ../../../ + done + + log "----------------------------------------" + log "Export streaming ONNX transducer models " + log "----------------------------------------" + + ./zipformer/export-onnx-streaming.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --causal 1 \ + --avg 1 \ + --epoch 20 \ + --use-averaged-model 0 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --use-ctc 0 \ + --fp16 1 + + ls -lh $repo/exp + + log "------------------------------------------------------------" + log "Test exported streaming ONNX transducer models (Python code)" + log "------------------------------------------------------------" + + log "test fp32" + ./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav + + log "test int8" + ./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav + + log "test fp16" + ./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.fp16.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.fp16.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.fp16.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav + + name=( + sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-13 + sherpa-onnx-streaming-zipformer-multi-zh-hans-int8-2023-12-13 + sherpa-onnx-streaming-zipformer-multi-zh-hans-fp16-2023-12-13 + ) + + for n in ${name[@]}; do + url=https://huggingface.co/csukuangfj/$n + GIT_LFS_SKIP_SMUDGE=1 git clone $url + dst=$(basename $url) + if [[ $n == sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-13 ]]; then + cp -v $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.onnx $dst + cp -v $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx $dst + cp -v $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.onnx $dst + elif [[ $n == sherpa-onnx-streaming-zipformer-multi-zh-hans-int8-2023-12-13 ]]; then + cp -v $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx $dst + cp -v $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx $dst + cp -v $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx $dst + elif [[ $n == sherpa-onnx-streaming-zipformer-multi-zh-hans-fp16-2023-12-13 ]]; then + cp -v $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.fp16.onnx $dst + cp -v $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.fp16.onnx $dst + cp -v $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.fp16.onnx $dst + fi + + cp -v $repo/data/lang_bpe_2000/tokens.txt $dst + cp -v $repo/data/lang_bpe_2000/bpe.model $dst + mkdir -p $dst/test_wavs + cp -v $repo/test_wavs/*.wav $dst/test_wavs + cd $dst + git lfs track "*.onnx" "bpe.model" "*.wav" + ls -lh + file bpe.model + git status + git add . 
+ git commit -m "upload model" && git push https://csukuangfj:${HF_TOKEN}@huggingface.co/csukuangfj/$dst main || true + + log "Upload models to https://github.com/k2-fsa/sherpa-onnx" + rm -rf .git + rm -fv .gitattributes + cd .. + tar cjfv $dst.tar.bz2 $dst + ls -lh *.tar.bz2 + mv -v $dst.tar.bz2 ../../../ + done +} + +function run_2023_12_12_streaming() { + log "Upload onnx transducer models to huggingface" + + url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12 + GIT_LFS_SKIP_SMUDGE=1 git clone $url + dst=$(basename $url) + cp -v $repo/exp/encoder*.onnx $dst + cp -v $repo/exp/decoder*.onnx $dst + cp -v $repo/exp/joiner*.onnx $dst + cp -v $repo/data/lang_bpe_2000/tokens.txt $dst + cp -v $repo/data/lang_bpe_2000/bpe.model $dst + mkdir -p $dst/test_wavs + cp -v $repo/test_wavs/*.wav $dst/test_wavs + cd $dst + git lfs track "*.onnx" bpe.model "*.wav" + git add . + git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true + + log "Upload models to https://github.com/k2-fsa/sherpa-onnx" + rm -rf .git + rm -fv .gitattributes + cd .. + tar cjfv $dst.tar.bz2 $dst + ls -lh *.tar.bz2 + mv -v $dst.tar.bz2 ../../../ +} + +function run_yuekai_large() { + repo_url=https://csukuangfj:${HF_TOKEN}@huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-large + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + pushd $repo + git lfs pull --include pretrained.pt + mv pretrained.pt epoch-99.pt + curl -SL -O https://huggingface.co/pingzxy/icefall-asr-multi-zh-hans-zipformer-large-onnx/resolve/main/tokens.txt + popd + + log "----------------------------------------" + log "Export streaming ONNX CTC models " + log "----------------------------------------" + ./zipformer/export-onnx-streaming-ctc.py \ + --exp-dir $repo/ \ + --tokens $repo/tokens.txt \ + --causal 1 \ + --avg 1 \ + --epoch 99 \ + --use-averaged-model 0 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --use-ctc 1 \ + \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 768,1024,1536,2048,1536,768 \ + --encoder-dim 256,384,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + \ + --fp16 1 \ + --use-whisper-features 1 + + + ls -lh $repo/ + pushd $repo + +cat >README.md <README.md < DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. 
- # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info(f"About to get train {self.args.subset} cuts") - if self.args.subset == "XL": - filenames = glob.glob( - f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz" - ) - pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) - idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) - sorted_filenames = [f[1] for f in idx_filenames] - logging.info( - f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode" - ) - - cuts_train = lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) - else: - path = ( - self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" - ) - cuts_train = CutSet.from_jsonl_lazy(path) - return cuts_train - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" - ) - if self.args.small_dev: - return cuts_valid.subset(first=1000) - else: - return cuts_valid - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" - ) - - @lru_cache() - def fsc_train_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz" - ) - - @lru_cache() - def fsc_valid_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands valid 
cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz" - ) - - @lru_cache() - def fsc_test_small_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands small test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz" - ) - - @lru_cache() - def fsc_test_large_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands large test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz" - ) diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py new file mode 120000 index 000000000..75bc3d45a --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../../ASR/zipformer/asr_datamodule.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index 39d8fc6cd..2d88b6e55 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -97,6 +97,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -961,7 +962,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py index bf50bf5ea..63a38a4cc 100755 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -805,7 +811,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py index 485ea69c9..406749f22 100755 --- a/egs/ksponspeech/ASR/zipformer/train.py +++ b/egs/ksponspeech/ASR/zipformer/train.py @@ -92,6 +92,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + 
torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -942,7 +943,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1333,7 +1334,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 7e0bf5b7b..fc866f83b 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -667,7 +667,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -707,7 +709,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.method in [ diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 38b60fcb9..5b3a021ad 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -271,7 +271,7 @@ def main(): use_feat_batchnorm=params.use_feat_batchnorm, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -351,7 +351,9 @@ def main(): "attention-decoder", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -362,7 +364,9 @@ def main(): "attention-decoder", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 0b271a51c..349e8f02d 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -774,7 +774,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -814,7 +816,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + 
params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.method in [ diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index c4a13b101..14c132ada 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -65,7 +65,6 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -84,9 +83,11 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -420,7 +421,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -629,7 +630,7 @@ def train_one_epoch( scheduler: LRSchedulerType, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -676,7 +677,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -965,7 +966,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index e6327bb5e..cf58fd18d 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -868,7 +868,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -907,7 +909,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py index 19b26361e..f8e3fa43b 100755 --- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py @@ -334,7 +334,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -345,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py index a0cdfcf03..e528b2cb8 100755 --- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py @@ -290,7 +290,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -386,7 +386,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -397,7 +399,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py 
b/egs/librispeech/ASR/conformer_ctc3/train.py index a2f1125ca..64e77f421 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -95,9 +94,11 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -493,7 +494,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -694,7 +695,7 @@ def train_one_epoch( graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -743,7 +744,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1004,7 +1005,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index 74f6e73fa..01fcf0685 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -574,7 +574,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + torch.load( + f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False + ) ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -609,7 +611,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False + ) G = k2.Fsa.from_dict(d).to(device) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index ca21bd6bf..fc33f9512 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py 
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -93,7 +92,14 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -560,7 +566,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -727,7 +733,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -772,7 +778,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1002,7 +1008,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 23ddb6bec..b00cc6cc6 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -93,7 +92,14 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -560,7 +566,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] 
= None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -727,7 +733,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -772,7 +778,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1001,7 +1007,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index d19d50ae6..ec39d5b36 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -72,11 +72,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) if Path(f"data/lm/{lm}.pt").is_file(): logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"data/lm/{lm}.pt") + d = torch.load(f"data/lm/{lm}.pt", weights_only=False) G = k2.Fsa.from_dict(d) else: logging.info(f"Loading {lm}.fst.txt") diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 709b14070..bd25cfa29 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -66,11 +66,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: An FSA representing LG. 
""" lexicon = Lexicon(lang_dir) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) if Path(f"data/lm/{lm}.pt").is_file(): logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"data/lm/{lm}.pt") + d = torch.load(f"data/lm/{lm}.pt", weights_only=False) G = k2.Fsa.from_dict(d) else: logging.info(f"Loading {lm}.fst.txt") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 856c9d945..8c75eb871 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -750,7 +750,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index e7bad7ed8..9f148b348 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -156,7 +156,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -192,7 +192,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 42c3a5d7f..f29d1d9db 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -238,7 +238,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index feb81d500..e23da3b56 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -66,7 +66,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -82,9 +81,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -521,7 +522,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
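Note that every hunk in this patch replacing torch.cuda.amp.autocast(...) with torch_autocast(...) relies on a compatibility wrapper exported from icefall/utils.py; its implementation is not part of this diff. A minimal sketch of what such a wrapper is assumed to look like (names and fallback behaviour are assumptions, not taken from the patch):

import torch


def torch_autocast(device_type: str = "cuda", **kwargs):
    # Assumed sketch: prefer the newer torch.amp.autocast(device_type, ...)
    # API and fall back to the deprecated torch.cuda.amp.autocast on older
    # PyTorch releases that do not ship torch.amp.autocast.
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        return torch.amp.autocast(device_type, **kwargs)
    return torch.cuda.amp.autocast(**kwargs)


# Call sites in the converted hunks then read, e.g.:
#     with torch_autocast(enabled=params.use_fp16):
#         loss, loss_info = compute_loss(...)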
@@ -717,7 +718,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -763,7 +764,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1023,7 +1024,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 1a724830b..cfbbb334c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -935,7 +935,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index 4957d14b1..5aafe10af 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index dcff088e2..888f9931e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -241,7 +241,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 4fc4fa7f8..1b31b5485 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -74,7 +74,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -90,9 +89,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -560,7 +561,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -772,7 +773,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -848,7 +849,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1176,7 +1177,7 @@ def run(rank, world_size, args): else: logging.info("Skip scan_pessimistic_batches_for_oom") - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index a2b4f9e1a..e25b79e2e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -815,7 +815,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index e39637bd8..619e783b0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -239,7 +239,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 2c1cef3a3..e169b499f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -66,7 +66,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -82,9 +81,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -551,7 +552,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = 
None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -747,7 +748,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -793,7 +794,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1067,7 +1068,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index ca8c28af1..3b6ce9b89 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -21,7 +21,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -141,7 +141,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -176,7 +176,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 5b595c76c..5850555cd 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -10,9 +10,11 @@ from typing import Optional, Tuple import torch from scaling import ScaledLinear from torch import Tensor, nn -from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd +from torch.cuda.amp import custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined +from icefall.utils import create_grad_scaler, torch_autocast + # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base.
@@ -330,14 +332,14 @@ def _test_knowledge_base_lookup_autocast(): optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) - scaler = GradScaler(enabled=True) + scaler = create_grad_scaler(enabled=True) start = timeit.default_timer() for epoch in range(150): for n, (x, y) in enumerate(train_pairs): y_out = m(x) - with torch.cuda.amp.autocast(enabled=True): + with torch_autocast(enabled=True): loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 931341cc4..0611fd8cb 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -76,7 +75,14 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + create_grad_scaler, + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -453,7 +459,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -608,7 +614,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -650,7 +656,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -868,7 +874,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 2b872f1d5..2af8f3f4c 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -55,7 +55,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from noam import Noam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -68,7 +67,14 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) def add_model_arguments(parser: argparse.ArgumentParser): @@ -496,7 +502,7 @@ def save_checkpoint( model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, and training stats to file. @@ -650,7 +656,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -693,7 +699,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -939,7 +945,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 3c4500087..6d1da7440 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -741,7 +741,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 66c84b2a9..f1d16749c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1347,7 +1347,10 @@ def modified_beam_search( ( context_score, new_context_state, - ) = context_graph.forward_one_step(hyp.context_state, new_token) + _, + ) = context_graph.forward_one_step( + hyp.context_state, new_token, strict_mode=False + ) new_log_prob = topk_log_probs[k] + context_score @@ -2853,7 +2856,10 @@ def modified_beam_search_LODR( ( context_score, new_context_state, - ) = context_graph.forward_one_step(hyp.context_state, new_token) + _, + ) = context_graph.forward_one_step( + hyp.context_state, new_token, strict_mode=False + ) ys.append(new_token) state_cost = hyp.state_cost.forward_one_step(new_token) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index c57514193..5a4a74ebb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -754,7 +754,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 272d06c37..6a69332aa 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -157,7 +157,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -193,7 +193,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 6923f4d40..e6ddcab25 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -265,7 +265,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 6c19f2cb0..ce6c89614 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -78,7 +78,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -91,9 +90,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -523,7 +524,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
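Likewise, the GradScaler(...) call sites are routed through create_grad_scaler, another icefall/utils.py helper whose body is not shown in this patch. A hedged sketch, assuming it simply picks whichever GradScaler implementation the installed PyTorch provides:

import torch


def create_grad_scaler(device: str = "cuda", **kwargs):
    # Assumed sketch: torch.amp.GradScaler (recent PyTorch) takes the device
    # as its first argument; older releases only provide
    # torch.cuda.amp.GradScaler, hence the fallback.
    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler(device, **kwargs)
    return torch.cuda.amp.GradScaler(**kwargs)


# e.g. scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)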
@@ -716,7 +717,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -759,7 +760,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1000,7 +1001,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 7c62bfa58..18a3792b0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -921,7 +921,7 @@ def load_ngram_LM( if pt_file.is_file(): logging.info(f"Loading pre-compiled {pt_file}") - d = torch.load(pt_file, map_location=device) + d = torch.load(pt_file, map_location=device, weights_only=False) G = k2.Fsa.from_dict(d) G = k2.add_epsilon_self_loops(G) G = k2.arc_sort(G) @@ -1101,7 +1101,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale elif params.decoding_method in [ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index d45f6dadc..fbc4db921 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 05e6a6fba..19143fb5d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -274,7 +274,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index fdafa5a87..50670d1b2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -74,7 +74,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -87,9 +86,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -546,7 +547,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -755,7 +756,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -827,7 +828,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1126,7 +1127,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 5195a4ef6..925c01c7b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -913,7 +913,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 875b03f7f..c35f52309 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -96,9 +95,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -548,7 +549,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -744,7 +745,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -789,7 +790,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1047,7 +1048,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 7a3e63218..404d7a3d3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -972,7 +972,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index a9ce75a7b..9e2669379 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -238,7 +238,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 66dc5f991..6f9f92623 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -68,7 +68,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -84,9 +83,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -571,7 +572,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: 
Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -768,7 +769,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -814,7 +815,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1078,7 +1079,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index daadb70c9..a5d2457f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -185,7 +185,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -220,7 +220,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 8f033cb9a..35ee74f15 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -96,9 +95,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -519,7 +520,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -736,7 +737,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -781,7 +782,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1039,7 +1040,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 3bca7db2c..4f3fbaa81 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -348,7 +348,9 @@ class CodebookIndexExtractor: num_codebooks=self.params.num_codebooks, codebook_size=256, ) - quantizer.load_state_dict(torch.load(self.quantizer_file_path)) + quantizer.load_state_dict( + torch.load(self.quantizer_file_path, weights_only=False) + ) quantizer.to(self.params.device) return quantizer diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py index 27ef0a244..949a497ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py @@ -289,7 +289,7 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index eb8841cc4..048de7bb9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -910,7 +910,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py index 7095c3cc8..da1bf17fc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py @@ -813,7 +813,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index e7546ec45..d3d996b4a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -85,9 +84,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -635,7 +636,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -678,7 +679,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
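The recurring weights_only=False additions opt out of the stricter torch.load default introduced in PyTorch 2.6, which only unpickles an allow-list of tensor-like types. The graphs and checkpoints loaded in these recipes can contain other pickled objects, so the patch passes the flag explicitly. A brief illustration (the path below is a placeholder, not taken from this repository):

import k2
import torch

# Without weights_only=False, PyTorch >= 2.6 may refuse to load this file,
# since it stores a pickled dict rather than plain tensors only.
d = torch.load("path/to/LG.pt", map_location="cpu", weights_only=False)
decoding_graph = k2.Fsa.from_dict(d)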
@@ -857,7 +858,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -903,7 +904,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1219,7 +1220,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index add0e6a18..ed990b689 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import penalize_abs_values_gt -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -150,7 +150,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 4bf11ac24..fabda3aaa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 30a737061..5a317083c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -28,6 +28,8 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Embedding as ScaledEmbedding +from icefall.utils import torch_autocast + class ActivationBalancerFunction(torch.autograd.Function): @staticmethod @@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): def backward(ctx, x_grad: Tensor): (x_orig,) = ctx.saved_tensors with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module): ): return _no_op(x) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): eps = 1.0e-20 orig_x = x x = x.to(torch.float32) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 436ec53b4..f94da9788 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -86,10 +85,12 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, filter_uneven_sized_batch, setup_logger, str2bool, symlink_or_copy, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -581,7 +582,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save 
model, optimizer, scheduler and training stats to file. @@ -763,7 +764,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -809,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1106,7 +1107,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index cbde2a2e4..ee05627ae 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -44,7 +44,7 @@ from scaling import ( from torch import Tensor, nn from icefall.dist import get_rank -from icefall.utils import is_jit_tracing, make_pad_mask +from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast class Zipformer(EncoderInterface): @@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py index 629bec058..3b181bf23 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py @@ -633,7 +633,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -672,7 +674,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py index 7641fa5af..9e16c3fd7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py @@ -786,7 +786,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading 
{lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py index d1b7eec65..f7dd07f8d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py @@ -347,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +360,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py index a6e919e2f..f1ab2a3ec 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -150,7 +150,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py index 323ba2642..a13952dfa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py index 1e638aa7d..32242c94e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -286,7 +286,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -365,7 +365,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -376,7 +378,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index b35e56abc..a26f11c82 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -588,7 +589,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] 
= None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -787,7 +788,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -833,7 +834,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1128,7 +1129,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py index fa7144f0f..3af3ada2c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -624,7 +624,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -663,7 +665,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py index e2f08abc6..233f00236 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py @@ -808,7 +808,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py index e497787d3..025b146b9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py @@ -786,7 +786,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + 
torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py index 80604ef4a..70d9841bf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py @@ -347,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +360,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py index 0582b289f..bf0faf9f1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface -from icefall.utils import add_sos, make_pad_mask +from icefall.utils import add_sos, make_pad_mask, torch_autocast class Transducer(nn.Module): @@ -178,7 +178,7 @@ class Transducer(nn.Module): am = self.simple_am_proj(encoder_out_fr) lm = self.simple_lm_proj(decoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -213,7 +213,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py index a82f3562b..9ceec5f5a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py index b98756a54..431760f9a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -286,7 +286,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -362,7 +362,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -373,7 +375,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index c2d877a93..5585d74de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -63,7 +63,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -82,9 +81,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -581,7 +582,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = 
None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -778,7 +779,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -822,7 +823,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1118,7 +1119,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1217,7 +1218,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index 02029c108..aa2fe8e38 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -936,7 +936,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py index aa2dd17fb..f98851f50 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 8bd00bbef..4d8a2644d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -82,7 +81,14 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err 
import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -597,7 +603,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -764,7 +770,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -810,7 +816,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1124,7 +1130,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1224,7 +1230,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index c7e45564f..640d72b67 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -44,7 +44,7 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.utils import make_pad_mask, subsequent_chunk_mask +from icefall.utils import make_pad_mask, subsequent_chunk_mask, torch_autocast def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: @@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py index 35158ced4..61c1a9663 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py @@ -768,7 +768,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= 
params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py index a4f52ad7f..e95bb3357 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py @@ -788,7 +788,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index da5e144c9..4b97575e6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -70,7 +70,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -86,7 +85,14 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -615,7 +621,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
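The recurring substitutions above replace the deprecated torch.cuda.amp entry points with torch_autocast and create_grad_scaler from icefall.utils. Their implementations are not part of this patch; presumably they dispatch between the torch.amp API (which newer PyTorch releases prefer and which deprecates the CUDA-specific names) and the older torch.cuda.amp API, roughly as in this sketch:

    # Sketch only: an assumed shape for the icefall.utils helpers, not the actual code.
    from typing import Any

    import torch


    def torch_autocast(device_type: str = "cuda", **kwargs: Any):
        # Newer PyTorch deprecates torch.cuda.amp.autocast(...) in favour of
        # torch.amp.autocast("cuda", ...); fall back for older releases.
        if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
            return torch.amp.autocast(device_type=device_type, **kwargs)
        return torch.cuda.amp.autocast(**kwargs)


    def create_grad_scaler(device: str = "cuda", **kwargs: Any):
        # torch.amp.GradScaler is available from PyTorch 2.3 on; older releases
        # only provide the CUDA-specific torch.cuda.amp.GradScaler.
        if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
            return torch.amp.GradScaler(device, **kwargs)
        return torch.cuda.amp.GradScaler(**kwargs)

Both helpers are called exactly like the APIs they replace, e.g. torch_autocast(enabled=params.use_fp16) and create_grad_scaler(enabled=params.use_fp16, init_scale=1.0).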
@@ -795,7 +801,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -866,7 +872,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1218,7 +1224,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1320,7 +1326,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index e07777c9f..3cad83a0b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -747,7 +747,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 39a360796..e06594c27 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -24,7 +24,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import penalize_abs_values_gt -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -172,7 +172,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -207,7 +207,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index c29b8d8c9..693db2beb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 646f30ca1..ad14ec9dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -75,7 +75,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -91,7 +90,14 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -608,7 +614,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
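The weights_only=False arguments added to the torch.load calls throughout this patch account for PyTorch 2.6 flipping the default of weights_only to True. The files loaded here are not plain tensor archives: icefall checkpoints carry optimizer, scheduler and sampler state, and HLG.pt, LG.pt and G_4_gram.pt hold k2.Fsa dictionaries, so unpickling them needs the old behaviour restored explicitly. An illustrative call (the path is a placeholder):

    import torch

    # Any icefall checkpoint or graph file; "exp/pretrained.pt" is a placeholder.
    checkpoint = torch.load("exp/pretrained.pt", map_location="cpu", weights_only=False)
    model_state_dict = checkpoint["model"]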
@@ -790,7 +796,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -866,7 +872,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1219,7 +1225,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1321,7 +1327,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 1b52aa8b5..283252a46 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -311,8 +311,7 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, + buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 92529e06c..db12ab827 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -398,7 +398,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -428,7 +430,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False + ) G = k2.Fsa.from_dict(d).to(device) if params.method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index b3dfab64a..4ad7cb016 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -167,13 +167,15 @@ def main(): subsampling_factor=params.subsampling_factor, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"]) model.to(device) model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, 
"lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -181,7 +183,9 @@ def main(): if params.method == "whole-lattice-rescoring": logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py index cda03b56e..ec700626a 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py @@ -589,7 +589,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -628,7 +630,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py index cc4471e2b..1b329e8f3 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py @@ -663,7 +663,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py index 92dea3aa1..4b234a328 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py @@ -347,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +360,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py index 5c6956324..9714aa537 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py @@ -249,7 +249,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, 
map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py index 7698ada79..a2ea1dd06 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py @@ -286,7 +286,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -365,7 +365,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -376,7 +378,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 1bfd071de..368bd20fa 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -51,7 +51,6 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import AdamW from torch.optim.lr_scheduler import StepLR @@ -72,9 +71,11 @@ from icefall.lexicon import UniqLexicon from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -550,7 +551,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
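With the module-level "from torch.cuda.amp import GradScaler" imports removed, the scaler type hints become the string literal "GradScaler", which keeps the annotation readable without requiring the class to be importable at runtime; the actual object now comes from create_grad_scaler. The training loops keep the usual AMP pattern, shown here as a self-contained sketch with a stand-in model (assumes a CUDA device and the helpers sketched earlier):

    import torch
    import torch.nn as nn

    model = nn.Linear(80, 500).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = create_grad_scaler(enabled=True, init_scale=1.0)

    features = torch.randn(8, 80, device="cuda")
    targets = torch.randn(8, 500, device="cuda")

    with torch_autocast(enabled=True):
        # the forward pass runs in reduced precision where it is safe to do so
        loss = nn.functional.mse_loss(model(features), targets)

    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, then applies the update
    scaler.update()

Wrapping only the forward pass, as these scripts do, is also what lets the k2 loss computations opt back out via torch_autocast(enabled=False) where fp32 is required.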
@@ -757,7 +758,7 @@ def train_one_epoch( phone_lexicon: UniqLexicon, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -809,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1092,7 +1093,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1198,7 +1199,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 4d9bbf4b1..06b1c05b9 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -222,7 +222,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index fe9347b95..e17407c5f 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -947,7 +947,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -987,7 +989,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method in [ diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index cbfb3728e..6462d22f8 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -1013,7 +1013,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py index 3cda337c0..a4da83949 100755 --- a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py +++ b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py @@ -1049,7 +1049,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading 
{lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py index 99685f2fe..413b5bb1e 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py @@ -153,6 +153,13 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) + parser.add_argument( + "--fp16", + type=str2bool, + default=False, + help="Whether to export models in fp16", + ) + add_model_arguments(parser) return parser @@ -176,6 +183,15 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): onnx.save(model, filename) +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + class OnnxModel(nn.Module): """A wrapper for encoder_embed, Zipformer, and ctc_output layer""" @@ -407,7 +423,7 @@ def main(): opset_version = 13 logging.info("Exporting ctc model") - filename = params.exp_dir / f"model.onnx" + filename = params.exp_dir / "model.onnx" export_ctc_model_onnx( model, filename, @@ -420,7 +436,7 @@ def main(): logging.info("Generate int8 quantization models") - filename_int8 = params.exp_dir / f"model.int8.onnx" + filename_int8 = params.exp_dir / "model.int8.onnx" quantize_dynamic( model_input=filename, model_output=filename_int8, @@ -428,6 +444,10 @@ def main(): weight_type=QuantType.QInt8, ) + if params.fp16: + filename_fp16 = params.exp_dir / "model.fp16.onnx" + export_onnx_fp16(filename, filename_fp16) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py index b1f9d70ff..9a715eefd 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py @@ -150,12 +150,35 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) + parser.add_argument( + "--use-whisper-features", + type=str2bool, + default=False, + help="True to use whisper features. Must match the one used in training", + ) + + parser.add_argument( + "--fp16", + type=str2bool, + default=False, + help="Whether to export models in fp16", + ) + + parser.add_argument( + "--use-external-data", + type=str2bool, + default=False, + help="Set it to true for model file size > 2GB", + ) + add_model_arguments(parser) return parser -def add_meta_data(filename: str, meta_data: Dict[str, str]): +def add_meta_data( + filename: str, meta_data: Dict[str, str], use_external_data: bool = False +): """Add meta data to an ONNX model. It is changed in-place. Args: @@ -164,13 +187,46 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): meta_data: Key-value pairs. 
""" + filename = str(filename) + model = onnx.load(filename) for key, value in meta_data.items(): meta = model.metadata_props.add() meta.key = key meta.value = value - onnx.save(model, filename) + if use_external_data: + # For models file size > 2GB + external_filename = Path(filename).stem + + onnx.save( + model, + filename, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=external_filename + ".weights", + ) + else: + onnx.save(model, filename) + + +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + +def export_onnx_fp16_large_2gb(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path + + onnx_fp16_model = convert_float_to_float16_model_path( + onnx_fp32_path, keep_io_types=True + ) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) class OnnxModel(nn.Module): @@ -285,6 +341,8 @@ def export_streaming_ctc_model_onnx( encoder_filename: str, opset_version: int = 11, dynamic_batch: bool = True, + use_whisper_features: bool = False, + use_external_data: bool = False, ) -> None: model.encoder.__class__.forward = model.encoder.__class__.streaming_forward @@ -382,6 +440,10 @@ def export_streaming_ctc_model_onnx( "value_head_dims": value_head_dims, "num_heads": num_heads, } + + if use_whisper_features: + meta_data["feature"] = "whisper" + logging.info(f"meta_data: {meta_data}") for i in range(len(init_state[:-2]) // 6): @@ -428,7 +490,11 @@ def export_streaming_ctc_model_onnx( else {}, ) - add_meta_data(filename=encoder_filename, meta_data=meta_data) + add_meta_data( + filename=encoder_filename, + meta_data=meta_data, + use_external_data=use_external_data, + ) @torch.no_grad() @@ -559,12 +625,19 @@ def main(): opset_version = 13 logging.info("Exporting model") - model_filename = params.exp_dir / f"ctc-{suffix}.onnx" + + if params.use_external_data: + model_filename = f"ctc-{suffix}.onnx" + else: + model_filename = params.exp_dir / f"ctc-{suffix}.onnx" + export_streaming_ctc_model_onnx( model, - model_filename, + str(model_filename), opset_version=opset_version, dynamic_batch=params.dynamic_batch == 1, + use_whisper_features=params.use_whisper_features, + use_external_data=params.use_external_data, ) logging.info(f"Exported model to {model_filename}") @@ -574,13 +647,25 @@ def main(): logging.info("Generate int8 quantization models") + if params.use_external_data: + model_filename_int8 = f"ctc-{suffix}.int8.onnx" + else: model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx" - quantize_dynamic( - model_input=model_filename, - model_output=model_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) + + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + if params.fp16: + if params.use_external_data: + model_filename_fp16 = f"ctc-{suffix}.fp16.onnx" + export_onnx_fp16_large_2gb(model_filename, model_filename_fp16) + else: + model_filename_fp16 = params.exp_dir / f"ctc-{suffix}.fp16.onnx" + export_onnx_fp16(model_filename, model_filename_fp16) if __name__ == "__main__": diff --git 
a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 43ec5d59b..daeb86f6a 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -176,12 +176,47 @@ def get_parser(): help="Whether to export models in fp16", ) + parser.add_argument( + "--use-whisper-features", + type=str2bool, + default=False, + help="True to use whisper features. Must match the one used in training", + ) + + parser.add_argument( + "--use-external-data", + type=str2bool, + default=False, + help="Set it to true for model file size > 2GB", + ) + add_model_arguments(parser) return parser -def add_meta_data(filename: str, meta_data: Dict[str, str]): +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + +def export_onnx_fp16_large_2gb(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path + + onnx_fp16_model = convert_float_to_float16_model_path( + onnx_fp32_path, keep_io_types=True + ) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + +def add_meta_data( + filename: str, meta_data: Dict[str, str], use_external_data: bool = False +): """Add meta data to an ONNX model. It is changed in-place. Args: @@ -196,7 +231,19 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): meta.key = key meta.value = value - onnx.save(model, filename) + if use_external_data: + # For models file size > 2GB + external_filename = Path(filename).stem + + onnx.save( + model, + filename, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=external_filename + ".weights", + ) + else: + onnx.save(model, filename) class OnnxEncoder(nn.Module): @@ -357,6 +404,8 @@ def export_encoder_model_onnx( opset_version: int = 11, feature_dim: int = 80, dynamic_batch: bool = True, + use_whisper_features: bool = False, + use_external_data: bool = False, ) -> None: encoder_model.encoder.__class__.forward = ( encoder_model.encoder.__class__.streaming_forward @@ -456,6 +505,9 @@ def export_encoder_model_onnx( "value_head_dims": value_head_dims, "num_heads": num_heads, } + if use_whisper_features: + meta_data["feature"] = "whisper" + logging.info(f"meta_data: {meta_data}") for i in range(len(init_state[:-2]) // 6): @@ -502,7 +554,11 @@ def export_encoder_model_onnx( else {}, ) - add_meta_data(filename=encoder_filename, meta_data=meta_data) + add_meta_data( + filename=encoder_filename, + meta_data=meta_data, + use_external_data=use_external_data, + ) def export_decoder_model_onnx( @@ -751,13 +807,18 @@ def main(): opset_version = 13 logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + if params.use_external_data: + encoder_filename = f"encoder-{suffix}.onnx" + else: + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" export_encoder_model_onnx( encoder, - encoder_filename, + str(encoder_filename), opset_version=opset_version, feature_dim=params.feature_dim, dynamic_batch=params.dynamic_batch == 1, + use_whisper_features=params.use_whisper_features, + use_external_data=params.use_external_data, ) logging.info(f"Exported encoder to 
{encoder_filename}") @@ -782,24 +843,20 @@ def main(): logging.info(f"Exported joiner to {joiner_filename}") if params.fp16: - from onnxconverter_common import float16 - logging.info("Generate fp16 models") - encoder = onnx.load(encoder_filename) - encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True) - encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" - onnx.save(encoder_fp16, encoder_filename_fp16) + if params.use_external_data: + encoder_filename_fp16 = f"encoder-{suffix}.fp16.onnx" + export_onnx_fp16_large_2gb(encoder_filename, encoder_filename_fp16) + else: + encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" + export_onnx_fp16(encoder_filename, encoder_filename_fp16) - decoder = onnx.load(decoder_filename) - decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True) decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" - onnx.save(decoder_fp16, decoder_filename_fp16) + export_onnx_fp16(decoder_filename, decoder_filename_fp16) - joiner = onnx.load(joiner_filename) - joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True) joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" - onnx.save(joiner_fp16, joiner_filename_fp16) + export_onnx_fp16(joiner_filename, joiner_filename_fp16) # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection @@ -807,7 +864,11 @@ def main(): if params.enable_int8_quantization: logging.info("Generate int8 quantization models") - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + if params.use_external_data: + encoder_filename_int8 = f"encoder-{suffix}.int8.onnx" + else: + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( model_input=encoder_filename, model_output=encoder_filename_int8, diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index a56a7a3e6..03c7d6f82 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -70,7 +70,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from onnxconverter_common import float16 from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -182,6 +181,15 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): onnx.save(model, filename) +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + class OnnxEncoder(nn.Module): """A wrapper for Zipformer and the encoder_proj from the joiner""" @@ -595,20 +603,14 @@ def main(): if params.fp16: logging.info("Generate fp16 models") - encoder = onnx.load(encoder_filename) - encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True) encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" - onnx.save(encoder_fp16, encoder_filename_fp16) + export_onnx_fp16(encoder_filename, encoder_filename_fp16) - decoder = onnx.load(decoder_filename) - decoder_fp16 = float16.convert_float_to_float16(decoder, 
keep_io_types=True) decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" - onnx.save(decoder_fp16, decoder_filename_fp16) + export_onnx_fp16(decoder_filename, decoder_filename_fp16) - joiner = onnx.load(joiner_filename) - joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True) joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" - onnx.save(joiner_fp16, joiner_filename_fp16) + export_onnx_fp16(joiner_filename, joiner_filename_fp16) # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 2ff631914..94e8b273a 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -95,11 +94,13 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + create_grad_scaler, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -140,8 +141,8 @@ def add_finetune_arguments(parser: argparse.ArgumentParser): type=str2bool, default=False, help=""" - Whether to adapt. If true, we will mix 5% of the new data - with 95% of the original data to fine-tune. This is useful + Whether to adapt. If true, we will mix 5%% of the new data + with 95%% of the original data to fine-tune. This is useful if you want to maintain the performance on the original domain """, ) @@ -765,7 +766,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -808,7 +809,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
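The ONNX fp16 export above is consolidated into export_onnx_fp16 and export_onnx_fp16_large_2gb, which rely on onnxmltools rather than onnxconverter_common; convert_float_to_float16_model_path works from the on-disk file, which is how models above the 2 GB protobuf limit are handled (those are additionally saved with save_as_external_data, placing the weights in a sibling .weights file). Since keep_io_types=True is passed in every case, the converted models still expose float32 inputs and outputs, so existing inference code does not need to change. A quick check with onnxruntime (the filename is a placeholder):

    import onnxruntime as ort

    # Placeholder path; any of the exported *.fp16.onnx files would do.
    sess = ort.InferenceSession(
        "encoder-epoch-99-avg-1.fp16.onnx", providers=["CPUExecutionProvider"]
    )
    for inp in sess.get_inputs():
        # keep_io_types=True means these still report tensor(float), not tensor(float16)
        print(inp.name, inp.type)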
@@ -985,7 +986,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dls: torch.utils.data.DataLoader, valid_sets: List[str], - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1049,7 +1050,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1134,7 +1135,7 @@ def train_one_epoch( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " + f"lr: {cur_lr: .2e}, " + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) @@ -1373,7 +1374,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1474,7 +1475,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py index fcd07ae34..d1978df52 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py @@ -346,7 +346,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -357,7 +359,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index c7dbe1e0a..6ef250819 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -25,7 +25,7 @@ from encoder_interface import EncoderInterface from lhotse.dataset import SpecAugment from scaling import ScaledLinear -from icefall.utils import add_sos, make_pad_mask, time_warp +from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast class AsrModel(nn.Module): @@ -210,10 +210,10 @@ class AsrModel(nn.Module): ) # Compute consistency regularization loss - exchanged_targets = ctc_output.detach().chunk(2, dim=0) - exchanged_targets = torch.cat( - [exchanged_targets[1], exchanged_targets[0]], dim=0 - ) # exchange: [x1, x2] -> [x2, x1] + batch_size = ctc_output.shape[0] + assert batch_size % 2 == 0, batch_size + # exchange: [x1, x2] -> [x2, x1] + exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0) cr_loss = nn.functional.kl_div( input=ctc_output, 
target=exchanged_targets, @@ -285,7 +285,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -320,7 +320,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py index 9f3571b08..65ea7c7f2 100755 --- a/egs/librispeech/ASR/zipformer/pretrained.py +++ b/egs/librispeech/ASR/zipformer/pretrained.py @@ -289,7 +289,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index 4341ef61f..90a6ff5b8 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -305,7 +305,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -389,7 +389,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -400,7 +402,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d345c2931..22aa1b1ca 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -26,6 +26,8 @@ import torch.nn as nn from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from icefall.utils import torch_autocast + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) @@ -160,8 +162,10 @@ class PiecewiseLinear(object): extra_x_vals.append(extra_x_val) if len(extra_x_vals) > 0: x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] + + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( PiecewiseLinear(*zip(x_vals, y_vals1)), PiecewiseLinear(*zip(x_vals, y_vals2)), @@ -306,7 +310,7 
@@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -759,7 +763,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1014,7 +1018,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1353,7 +1357,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1401,9 +1405,9 @@ class SwooshL(torch.nn.Module): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 if not x.requires_grad: - return k2.swoosh_l_forward(x) + return k2.swoosh_l_forward(x).to(x.dtype) else: - return k2.swoosh_l(x) + return k2.swoosh_l(x).to(x.dtype) # return SwooshLFunction.apply(x) @@ -1430,7 +1434,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1475,9 +1479,9 @@ class SwooshR(torch.nn.Module): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not x.requires_grad: - return k2.swoosh_r_forward(x) + return k2.swoosh_r_forward(x).to(x.dtype) else: - return k2.swoosh_r(x) + return k2.swoosh_r(x).to(x.dtype) # return SwooshRFunction.apply(x) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index f8864d58b..42ae9b9f2 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -79,7 +79,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -98,9 +97,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -829,7 +830,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
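Throughout these files, `torch.cuda.amp.autocast(...)` is replaced by a `torch_autocast(...)` helper imported from `icefall.utils`. The helper itself is defined outside this diff; a minimal sketch of the idea, assuming it simply dispatches to the newer `torch.amp.autocast("cuda", ...)` API when it exists (recent PyTorch releases deprecate the `torch.cuda.amp` entry point):

```python
import torch


def torch_autocast(device_type: str = "cuda", **kwargs):
    """Illustrative sketch only; the real helper lives in icefall/utils.py."""
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        # torch.amp.autocast is the non-deprecated, device-agnostic entry point.
        return torch.amp.autocast(device_type=device_type, **kwargs)
    # Fall back to the legacy API on older PyTorch versions.
    return torch.cuda.amp.autocast(**kwargs)
```

Call sites then read `with torch_autocast(enabled=params.use_fp16): ...` or `with torch_autocast(enabled=False): ...`, exactly as in the hunks above, without tying the modules to the deprecated import.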
@@ -1034,7 +1035,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, @@ -1101,9 +1102,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype - ): + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): loss, loss_info = compute_loss( params=params, model=model, @@ -1449,7 +1448,7 @@ def run(rank, world_size, args): spec_augment=spec_augment, ) - scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1551,9 +1550,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype - ): + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2a0ae0129..e83a89400 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -47,6 +47,8 @@ from scaling import ( ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ @@ -1873,7 +1875,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_adapter/decode.py b/egs/librispeech/ASR/zipformer_adapter/decode.py index 91533be8d..e8798aed6 100755 --- a/egs/librispeech/ASR/zipformer_adapter/decode.py +++ b/egs/librispeech/ASR/zipformer_adapter/decode.py @@ -1005,7 +1005,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py index bbc582f50..66c401761 100755 --- a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py +++ b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py @@ -1050,7 +1050,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 3511590da..fcd7272e9 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ 
-67,7 +67,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -762,7 +763,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -805,7 +806,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -982,7 +983,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dls: torch.utils.data.DataLoader, valid_sets: List[str], - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1052,7 +1053,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1397,7 +1398,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1498,7 +1499,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py index 8e2dfdd72..8bc163db5 100644 --- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py +++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py @@ -50,6 +50,8 @@ from scaling import ( ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ @@ -1916,7 +1918,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_ctc/decode.py b/egs/librispeech/ASR/zipformer_ctc/decode.py index 7f605e2c8..b9eed099c 100755 --- a/egs/librispeech/ASR/zipformer_ctc/decode.py +++ b/egs/librispeech/ASR/zipformer_ctc/decode.py @@ -679,7 +679,9 @@ def main(): H = None 
bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -719,7 +721,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.method in [ diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py index 60112a84e..bd3bfa332 100755 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -46,7 +46,6 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, LRScheduler, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter @@ -65,7 +64,14 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler] @@ -533,7 +539,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
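The repeated `weights_only=False` additions are needed because PyTorch 2.6 flips the default of `torch.load` to `weights_only=True`, which refuses to unpickle arbitrary Python objects such as the dicts behind `HLG.pt`, `LG.pt`, `G_4_gram.pt`, or full training checkpoints. A minimal illustration of the pattern (the path is a placeholder; any graph produced by the recipe's prepare stage applies):

```python
import k2
import torch

# Placeholder path for illustration.
HLG = k2.Fsa.from_dict(
    torch.load("data/lang_bpe_500/HLG.pt", map_location="cpu", weights_only=False)
)
```

Note that `weights_only=False` re-enables full unpickling, so it should only be used on files produced by the recipes themselves or from sources you trust.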
@@ -687,7 +693,7 @@ def train_one_epoch( graph_compiler: BpeCtcTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -726,7 +732,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -987,7 +993,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py index 4d93a905f..acc814a00 100755 --- a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py +++ b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py @@ -1050,7 +1050,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py index 3f36f229f..c26a2f5cc 100755 --- a/egs/librispeech/ASR/zipformer_lora/finetune.py +++ b/egs/librispeech/ASR/zipformer_lora/finetune.py @@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -96,9 +95,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -775,7 +776,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -818,7 +819,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
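In the same spirit, `GradScaler(...)` construction is replaced by `create_grad_scaler(...)` from `icefall.utils`, and the annotations become the string form `Optional["GradScaler"]` so these modules no longer import the deprecated `torch.cuda.amp.GradScaler`. A sketch of what such a factory might look like, assuming it only papers over the old and new PyTorch APIs (the real definition is in `icefall/utils.py` and may differ):

```python
import torch


def create_grad_scaler(device: str = "cuda", **kwargs):
    """Illustrative sketch only; see icefall/utils.py for the real helper."""
    if hasattr(torch.amp, "GradScaler"):
        # PyTorch >= 2.3 ships a device-agnostic torch.amp.GradScaler.
        return torch.amp.GradScaler(device, **kwargs)
    return torch.cuda.amp.GradScaler(**kwargs)
```

The call sites stay as simple as before, e.g. `scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)`, matching the hunks above.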
@@ -995,7 +996,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dls: torch.utils.data.DataLoader, valid_sets: List[str], - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1065,7 +1066,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1406,7 +1407,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1507,7 +1508,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py index 8d7aa8027..1347570df 100644 --- a/egs/librispeech/ASR/zipformer_lora/scaling.py +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -27,6 +27,8 @@ import torch.nn.functional as F from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from icefall.utils import torch_autocast + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) @@ -307,7 +309,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -863,7 +865,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1118,7 +1120,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1457,7 +1459,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1534,7 +1536,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py index 9ab214e86..2b83d58ef 100755 --- a/egs/librispeech/ASR/zipformer_lora/train.py +++ b/egs/librispeech/ASR/zipformer_lora/train.py @@ -76,7 +76,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from 
torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -94,9 +93,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -707,7 +708,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -883,7 +884,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -947,7 +948,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1252,7 +1253,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1352,7 +1353,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py index 43865609a..b84b1c32a 100644 --- a/egs/librispeech/ASR/zipformer_lora/zipformer.py +++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py @@ -49,6 +49,8 @@ from scaling import ( ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ @@ -1905,7 +1907,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py index 33c0bf199..bd3ce21f5 100755 --- a/egs/librispeech/ASR/zipformer_mmi/decode.py +++ b/egs/librispeech/ASR/zipformer_mmi/decode.py @@ -569,7 +569,9 @@ def main(): if params.decoding_method == "nbest-rescoring-LG": lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") - LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device)) + LG = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device, weights_only=False) + ) LG = k2.Fsa.from_fsas([LG]).to(device) LG.lm_scores = LG.scores.clone() @@ -602,7 +604,11 @@ def main(): torch.save(G.as_dict(), params.lang_dir / f"{order}gram.pt") else: logging.info(f"Loading pre-compiled {order}gram.pt") - d = torch.load(params.lang_dir / f"{order}gram.pt", 
map_location=device) + d = torch.load( + params.lang_dir / f"{order}gram.pt", + map_location=device, + weights_only=False, + ) G = k2.Fsa.from_dict(d) G.lm_scores = G.scores.clone() diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py index 6990c90a0..d5667cafa 100755 --- a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py @@ -308,7 +308,9 @@ def main(): if method == "nbest-rescoring-LG": lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") - LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device)) + LG = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device, weights_only=False) + ) LG = k2.Fsa.from_fsas([LG]).to(device) LG.lm_scores = LG.scores.clone() LM = LG @@ -317,7 +319,9 @@ def main(): assert order in ("3", "4") order = int(order) logging.info(f"Loading pre-compiled {order}gram.pt") - d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device) + d = torch.load( + params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) G.lm_scores = G.scores.clone() LM = G diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py index 1e7afc777..ca860b877 100755 --- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py @@ -269,7 +269,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -331,7 +331,9 @@ def main(): if method == "nbest-rescoring-LG": lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") - LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device)) + LG = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device, weights_only=False) + ) LG = k2.Fsa.from_fsas([LG]).to(device) LG.lm_scores = LG.scores.clone() LM = LG @@ -340,7 +342,9 @@ def main(): assert order in ("3", "4") order = int(order) logging.info(f"Loading pre-compiled {order}gram.pt") - d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device) + d = torch.load( + params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) G.lm_scores = G.scores.clone() LM = G diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index c1785a328..e0ca0a6a5 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -64,7 +64,6 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -87,9 +86,11 @@ from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -514,7 +515,7 @@ def save_checkpoint( optimizer: 
Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -696,7 +697,7 @@ def train_one_epoch( mmi_graph_compiler: MmiTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -744,7 +745,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1037,7 +1038,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1138,7 +1139,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py index 17daa3c9d..0080513f3 100644 --- a/egs/librispeech/SSL/hubert/finetune.py +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -86,6 +86,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -816,7 +817,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py index 2723cc770..1ff2b03c0 100644 --- a/egs/librispeech/SSL/hubert/finetune_ce.py +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, @@ -816,7 +817,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git 
a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py index 46a968b69..2c2077376 100644 --- a/egs/librispeech/SSL/hubert/model.py +++ b/egs/librispeech/SSL/hubert/model.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class AsrModel(nn.Module): @@ -221,7 +221,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -256,7 +256,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py index f183d90fd..240cd2c0d 100644 --- a/egs/librispeech/SSL/hubert/pretrain.py +++ b/egs/librispeech/SSL/hubert/pretrain.py @@ -80,6 +80,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -644,7 +645,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py index 94948695d..12f95c16f 100644 --- a/egs/librispeech/SSL/hubert/pretrain_ce.py +++ b/egs/librispeech/SSL/hubert/pretrain_ce.py @@ -80,6 +80,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -644,7 +645,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py index c907b41c5..5bebf60f0 100644 --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, @@ -1115,7 +1116,7 @@ def train_one_epoch( batch_size = 
len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1504,7 +1505,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/model.py b/egs/librispeech/SSL/zipformer/model.py index 46a968b69..2c2077376 100644 --- a/egs/librispeech/SSL/zipformer/model.py +++ b/egs/librispeech/SSL/zipformer/model.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class AsrModel(nn.Module): @@ -221,7 +221,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -256,7 +256,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py index 937fb382e..d772f56d0 100644 --- a/egs/librispeech/SSL/zipformer/pretrain.py +++ b/egs/librispeech/SSL/zipformer/pretrain.py @@ -78,6 +78,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -944,7 +945,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1334,7 +1335,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py index e9eff3357..5071a91a8 100644 --- a/egs/librispeech/SSL/zipformer/zipformer.py +++ b/egs/librispeech/SSL/zipformer/zipformer.py @@ -22,6 +22,7 @@ import math import random import warnings from typing import List, Optional, Tuple, Union +from icefall.utils import torch_autocast import torch from encoder_interface import EncoderInterface @@ -1849,7 +1850,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py index 82c68803f..19cce1708 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train.py +++ 
b/egs/librispeech/WSASR/conformer_ctc2/train.py @@ -84,6 +84,7 @@ from icefall.utils import ( get_texts, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -757,7 +758,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1076,7 +1077,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py index b276d0587..a32183bf7 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py @@ -79,6 +79,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, encode_supervisions_otc, @@ -758,7 +759,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1078,7 +1079,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
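Several of the hunks above (the `RelPositionMultiheadAttentionWeights` diagnostics and the k2 transducer losses) keep wrapping numerically sensitive work in an autocast-disabled region; only the context manager changes. The underlying pattern is that the logarithm of near-zero attention probabilities underflows in fp16, so the entropy is computed in float32 even when the surrounding step runs under AMP. A small self-contained illustration of that pattern, not the recipe's exact code:

```python
import torch


def attention_weights_entropy(attn_weights: torch.Tensor) -> torch.Tensor:
    """Mean entropy of attention weights; the last dim is assumed to sum to 1."""
    with torch.no_grad():
        # Same intent as `with torch_autocast(enabled=False):` in the patch.
        with torch.amp.autocast("cuda", enabled=False):
            w = attn_weights.to(torch.float32)
            return -((w + 1.0e-20).log() * w).sum(dim=-1).mean()


# Uniform weights over 8 positions give entropy log(8) ~= 2.079.
w = torch.full((2, 4, 8), 1.0 / 8)
print(attention_weights_entropy(w))
```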
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py index 296f9a4f4..906025b7f 100755 --- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -73,6 +73,8 @@ def compute_fbank_ljspeech(num_jobs: int): f_min=0, f_max=8000, ) + if not torch.cuda.is_available(): + config.device = "cpu" prefix = "ljspeech" suffix = "jsonl.gz" diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py index c15db11f7..1b53465c0 100755 --- a/egs/multi_zh-hans/ASR/zipformer/pretrained.py +++ b/egs/multi_zh-hans/ASR/zipformer/pretrained.py @@ -328,9 +328,14 @@ def main(): logging.info(msg) def token_ids_to_words(token_ids: List[int]) -> str: - text = "" + byte_list = [] for i in token_ids: - text += token_table[i] + token = token_table[i] + if token.startswith("<0x") and token.endswith(">"): + byte_list.append(int(token[3:-1], 16)) + else: + byte_list += list(token.encode("utf-8")) + text = bytes(byte_list).decode("utf-8", errors='ignore') return text.replace("▁", " ").strip() if params.method == "fast_beam_search": diff --git a/egs/speech_llm/ASR_LLM/RESULTS.md b/egs/speech_llm/ASR_LLM/RESULTS.md index 830c70397..42dce80c5 100644 --- a/egs/speech_llm/ASR_LLM/RESULTS.md +++ b/egs/speech_llm/ASR_LLM/RESULTS.md @@ -42,8 +42,8 @@ huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishel # For multi-hans fine-tuned whisper model # huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt -# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct -huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct +# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct +huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct # First, we only train the projector and freeze other modules. torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ @@ -55,9 +55,10 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --deepspeed \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --use-flash-attn True \ - --use-lora False --unfreeze-llm False + --use-lora False \ + --unfreeze-llm False -# Then we jointly train the projector and LLM LoRA modules. +# Then, we jointly train the projector and LLM LoRA modules. 
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --max-duration 200 \ --exp-dir ./whisper_llm_zh/exp_test \ @@ -67,7 +68,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --deepspeed \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --use-flash-attn True \ - --use-lora True --unfreeze-llm True + --use-lora True \ + --unfreeze-llm True \ --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt ``` @@ -77,11 +79,11 @@ mkdir -p models/whisper models/qwen models/checkpoint huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B # For aishell fine-tuned whisper model -huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt +huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt # For multi-hans fine-tuned whisper model -# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt +# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt -huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct +huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt @@ -94,5 +96,6 @@ python3 ./whisper_llm_zh/decode.py \ --epoch 999 --avg 1 \ --manifest-dir data/fbank \ --use-flash-attn True \ - --use-lora True --dataset aishell + --use-lora True \ + --dataset aishell ``` diff --git a/egs/speech_llm/ASR_LLM/prepare.sh b/egs/speech_llm/ASR_LLM/prepare.sh old mode 100644 new mode 100755 index 6f5ed5448..8ca3c1c36 --- a/egs/speech_llm/ASR_LLM/prepare.sh +++ b/egs/speech_llm/ASR_LLM/prepare.sh @@ -7,6 +7,9 @@ set -eou pipefail stage=0 stop_stage=0 + +. shared/parse_options.sh || exit 1 + # All files generated by this script are saved in "data". # You can safely remove "data" and rerun this script to regenerate it. 
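The `token_ids_to_words` rewrite in `egs/multi_zh-hans/ASR/zipformer/pretrained.py` above switches from naive string concatenation to byte-level reassembly, so SentencePiece byte-fallback pieces such as `<0xE5>` are decoded into the characters they encode instead of leaking into the transcript. A self-contained demo with a toy token table (the table contents are made up for illustration):

```python
# Toy token table: "▁", "你", then the three byte pieces that encode "好".
token_table = {0: "▁", 1: "你", 2: "<0xE5>", 3: "<0xA5>", 4: "<0xBD>"}


def token_ids_to_words(token_ids):
    byte_list = []
    for i in token_ids:
        token = token_table[i]
        if token.startswith("<0x") and token.endswith(">"):
            # Byte-fallback piece: collect the raw byte value.
            byte_list.append(int(token[3:-1], 16))
        else:
            # Ordinary piece: contribute its UTF-8 bytes.
            byte_list += list(token.encode("utf-8"))
    text = bytes(byte_list).decode("utf-8", errors="ignore")
    return text.replace("▁", " ").strip()


print(token_ids_to_words([0, 1, 2, 3, 4]))  # -> "你好"
```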
mkdir -p data @@ -23,7 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # pip install huggingface_hub['cli'] # for aishell 1 - huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse + huggingface-cli download --repo-type dataset --local-dir data yuekai/aishell_whisper_fbank_lhotse fi @@ -31,9 +34,9 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface" # for multi-hans-zh - huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse - huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse - huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse + huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse + huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse + huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then @@ -41,6 +44,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then # for speechio test sets mkdir data_speechio - huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio + huggingface-cli download --repo-type model --local-dir data_speechio yuekai/icefall_asr_speechio mv data_speechio/fbank/* data/fbank fi diff --git a/egs/speech_llm/ASR_LLM/shared b/egs/speech_llm/ASR_LLM/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/speech_llm/ASR_LLM/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 882ce4fbf..3036b471e 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN from transformers import AutoModelForCausalLM, AutoTokenizer from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint +from icefall.checkpoint import load_checkpoint from icefall.env import get_env_info from icefall.utils import ( AttributeDict, @@ -357,43 +357,6 @@ def decode_dataset( Returns: Return a dict, whose key may be "beam-search". """ - - def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: - """ - Text normalization similar to M2MeT challenge baseline. 
- See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl - """ - if normalize == "none": - return text - elif normalize == "m2met": - import re - - text = text.replace(" ", "") - text = text.replace("", "") - text = text.replace("<%>", "") - text = text.replace("<->", "") - text = text.replace("<$>", "") - text = text.replace("<#>", "") - text = text.replace("<_>", "") - text = text.replace("", "") - text = text.replace("`", "") - text = text.replace("&", "") - text = text.replace(",", "") - if re.search("[a-zA-Z]", text): - text = text.upper() - text = text.replace("A", "A") - text = text.replace("a", "A") - text = text.replace("b", "B") - text = text.replace("c", "C") - text = text.replace("k", "K") - text = text.replace("t", "T") - text = text.replace(",", "") - text = text.replace("丶", "") - text = text.replace("。", "") - text = text.replace("、", "") - text = text.replace("?", "") - return text - results = [] num_cuts = 0 @@ -406,6 +369,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( @@ -418,12 +382,8 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_text = normalize_text_alimeeting(ref_text) - ref_words = ref_text.split() - print(f"ref: {ref_text}") - print(f"hyp: {''.join(hyp_words)}") - this_batch.append((cut_id, ref_words, hyp_words)) + for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_text)) results[lm_scale].extend(this_batch) @@ -439,40 +399,38 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): - - enable_log = True test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") + store_transcripts(filename=recog_path, texts=results, char_level=True) + logging.info(f"The transcripts are stored in {recog_path}") - # The following prints out WERs, per-word error statistics and aligned + # The following prints out CERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) - # we compute CER for aishell dataset. 
- results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: @@ -495,9 +453,13 @@ def main(): params = get_params() params.update(vars(args)) + + params.res_dir = params.exp_dir / f"{params.method}" + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" setup_logger( - f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + params.res_dir + / f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}" ) logging.info("Decoding started") @@ -574,23 +536,20 @@ def main(): if params.avg > 1: start = params.epoch - params.avg + 1 assert start >= 1, start - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - assert "model" not in checkpoint # deepspeed converted checkpoint only contains model state_dict filenames = [ - f"{params.exp_dir}/epoch-{epoch}.pt" + f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin" for epoch in range(start, params.epoch + 1) ] avg_checkpoint = average_checkpoints(filenames) model.load_state_dict(avg_checkpoint, strict=False) - filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save(avg_checkpoint, filename) + # filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + # torch.save(avg_checkpoint, filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", + map_location="cpu", ) model.load_state_dict(checkpoint, strict=False) @@ -643,8 +602,7 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 5f224c984..7947a60a5 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright 2023 Xiaomi Corp. 
(authors: Xiaoyu Yang) # 2024 Yuekai Zhang +# 2025 Yifan Yang # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -42,47 +43,32 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ """ import argparse -import copy import logging import os -import random import warnings from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple import deepspeed -import k2 import torch -import torch.multiprocessing as mp import torch.nn as nn import transformers import whisper from asr_datamodule import AsrDataModule from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict -from label_smoothing import LabelSmoothingLoss -from lhotse import CutSet, load_manifest from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector from multi_dataset import MultiDataset -from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +from peft import LoraConfig, get_peft_model from torch import Tensor from torch.utils.tensorboard import SummaryWriter from transformers import AutoModelForCausalLM, AutoTokenizer from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from icefall import diagnostics from icefall.dist import get_rank, get_world_size from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool DEFAULT_SPEECH_TOKEN = "" @@ -286,13 +272,6 @@ def compute_loss( Returns: Return a tuple of two elements. The first element is the loss tensor. """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. def preprocess( messages, @@ -347,46 +326,6 @@ def compute_loss( return input_ids, attention_mask, target_ids - def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: - """ - Text normalization similar to M2MeT challenge baseline. 
- See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl - """ - if normalize == "none": - return text - elif normalize == "m2met": - import re - - text = text.replace(" ", "") - text = text.replace("", "") - text = text.replace("<%>", "") - text = text.replace("<->", "") - text = text.replace("<$>", "") - text = text.replace("<#>", "") - text = text.replace("<_>", "") - text = text.replace("", "") - text = text.replace("`", "") - text = text.replace("&", "") - text = text.replace(",", "") - if re.search("[a-zA-Z]", text): - text = text.upper() - text = text.replace("A", "A") - text = text.replace("a", "A") - text = text.replace("b", "B") - text = text.replace("c", "C") - text = text.replace("k", "K") - text = text.replace("t", "T") - text = text.replace(",", "") - text = text.replace("丶", "") - text = text.replace("。", "") - text = text.replace("、", "") - text = text.replace("?", "") - return text - - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - device = next(model.parameters()).device feature = batch["inputs"] @@ -397,11 +336,10 @@ def compute_loss( batch_idx_train = params.batch_idx_train supervisions = batch["supervisions"] texts = batch["supervisions"]["text"] - # remove spaces in texts - texts = [normalize_text_alimeeting(text) for text in texts] messages = [] for i, text in enumerate(texts): + text = text.replace(" ", "") message = [ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, {"role": "assistant", "content": text}, @@ -516,14 +454,17 @@ def train_one_epoch( The rank of the node in DDP training. If no DDP is used, it should be set to 0. 
""" - model.encoder_projector.train() + model.train() + model.encoder.eval() + if not params.unfreeze_llm: + model.llm.eval() tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -533,6 +474,9 @@ def train_one_epoch( world_size=world_size, ) model.train() + model.encoder.eval() + if not params.unfreeze_llm: + model.llm.eval() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" @@ -648,7 +592,6 @@ def run(rank, world_size, args): speech_encoder_dim = whisper_model.dims.n_audio_state for name, param in speech_encoder.named_parameters(): param.requires_grad = False - speech_encoder.eval() tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) if params.use_flash_attn: @@ -671,7 +614,6 @@ def run(rank, world_size, args): if not params.unfreeze_llm: for name, param in llm.named_parameters(): param.requires_grad = False - llm.eval() else: if params.use_lora: lora_config = LoraConfig( @@ -728,7 +670,7 @@ def run(rank, world_size, args): logging.info(f"Device: {device}") model.to(device) - assert params.deepspeed and world_size > 1 + assert params.deepspeed logging.info("Using DeepSpeed") model, optimizer, _, scheduler = deepspeed.initialize( args=params, model=model, model_parameters=model.parameters() @@ -764,7 +706,7 @@ def run(rank, world_size, args): if params.sampler_state_dict_path: sampler_state_dict = torch.load(params.sampler_state_dict_path) sampler_state_dict["max_duration"] = params.max_duration - # TODO: load sampler state dict + train_dl = data_module.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) @@ -806,15 +748,15 @@ def run(rank, world_size, args): model.save_checkpoint( save_dir=params.exp_dir, - tag=f"epoch-{params.cur_epoch}", + tag=f"zero-epoch-{params.cur_epoch}", client_state={}, exclude_frozen_parameters=True, ) if rank == 0: convert_zero_checkpoint_to_fp32_state_dict( params.exp_dir, - f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", - tag=f"epoch-{params.cur_epoch}", + f"{params.exp_dir}/epoch-{params.cur_epoch}", + tag=f"zero-epoch-{params.cur_epoch}", exclude_frozen_parameters=True, ) # save sampler state dict into checkpoint @@ -824,7 +766,7 @@ def run(rank, world_size, args): f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt", ) - os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") + os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}") logging.info("Done!") @@ -865,6 +807,7 @@ def main(): torch.set_num_threads(1) torch.set_num_interop_threads(1) + warnings.filterwarnings("ignore", category=FutureWarning) run(rank=rank, world_size=world_size, args=args) diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py index 170f37767..2455f3630 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py @@ -422,7 +422,7 @@ def compute_loss( texts = batch["supervisions"]["text"] unk_id = params.unk_id - y = convert_texts_into_ids(texts, unk_id, sp=sp) + y = convert_texts_into_ids(texts, sp=sp) y = k2.RaggedTensor(y).to(device) with 
torch.set_grad_enabled(is_training): diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py index 6fed32e81..c6fa34e70 100755 --- a/egs/tedlium3/ASR/transducer_stateless/train.py +++ b/egs/tedlium3/ASR/transducer_stateless/train.py @@ -397,7 +397,7 @@ def compute_loss( texts = batch["supervisions"]["text"] unk_id = params.unk_id - y = convert_texts_into_ids(texts, unk_id, sp=sp) + y = convert_texts_into_ids(texts, sp=sp) y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 8b35187b1..2cc06df71 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule: group.add_argument( "--num-buckets", type=int, - default=30, + default=15, help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) @@ -292,8 +292,7 @@ class WenetSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, + buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index 8472b8531..0af7c1595 100755 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -108,7 +108,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 2: Finetune the model" + log "Stage 3: Finetune the model" # The following configuration of lr schedule should work well # You may also tune the following parameters to adjust learning rate schedule @@ -143,7 +143,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 1: Decode the finetuned model." + log "Stage 4: Decode the finetuned model." export CUDA_VISIBLE_DEVICES="0" for t in small large; do python ./zipformer/decode.py \ @@ -170,7 +170,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 2: Export the finetuned model." + log "Stage 5: Export the finetuned model." 
python ./zipformer/export.py \ --epoch 10 \ diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py index 340a41231..a628c7e58 100755 --- a/egs/wenetspeech/KWS/zipformer/decode.py +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -35,7 +35,6 @@ from lhotse.cut import Cut from train import add_model_arguments, get_model, get_params from icefall import ContextGraph -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index d19172b38..b1abfd79e 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -69,6 +69,7 @@ import argparse import copy import logging import warnings +from functools import partial from pathlib import Path from typing import List, Optional, Tuple, Union @@ -90,6 +91,7 @@ from train import ( add_training_arguments, compute_validation_loss, display_and_save_batch, + encode_text, get_adjusted_batch_count, get_model, get_params, @@ -100,7 +102,6 @@ from train import ( ) from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -110,11 +111,11 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, + num_tokens, setup_logger, str2bool, text_to_pinyin, @@ -254,7 +255,6 @@ def load_model_params( def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: @@ -289,7 +289,7 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts, sep="/") + y = [c.supervisions[0].tokens for c in supervisions["cut"]] y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): @@ -347,7 +347,6 @@ def train_one_epoch( model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, @@ -418,7 +417,6 @@ def train_one_epoch( loss, loss_info = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=True, ) @@ -436,7 +434,7 @@ def train_one_epoch( optimizer.zero_grad() except: # noqa save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params) raise if params.print_diagnostics and batch_idx == 5: @@ -523,7 +521,6 @@ def train_one_epoch( valid_info = compute_validation_loss( params=params, model=model, - graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, ) @@ -576,14 +573,10 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) + token_table = 
k2.SymbolTable.from_file(params.lang_dir / "tokens.txt") - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 if not params.use_transducer: params.ctc_loss_scale = 1.0 @@ -601,6 +594,9 @@ def run(rank, world_size, args): if params.continue_finetune: assert params.start_epoch > 0, params.start_epoch + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) checkpoints = load_checkpoint_if_available( params=params, model=model, model_avg=model_avg ) @@ -666,17 +662,10 @@ def run(rank, world_size, args): else: train_cuts = wenetspeech.nihaowenwen_train_cuts() - def encode_text(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = "/".join( - text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) - ) - c.supervisions[0].text = text - return c + _encode_text = partial(encode_text, token_table=token_table, params=params) train_cuts = train_cuts.filter(remove_short_utt) - train_cuts = train_cuts.map(encode_text) + train_cuts = train_cuts.map(_encode_text) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -691,7 +680,7 @@ def run(rank, world_size, args): valid_cuts = wenetspeech.nihaowenwen_dev_cuts() valid_cuts = valid_cuts.filter(remove_short_utt) - valid_cuts = valid_cuts.map(encode_text) + valid_cuts = valid_cuts.map(_encode_text) valid_dl = wenetspeech.valid_dataloaders(valid_cuts) if not params.print_diagnostics and params.scan_for_oom_batches: @@ -699,7 +688,6 @@ def run(rank, world_size, args): model=model, train_dl=train_dl, optimizer=optimizer, - graph_compiler=graph_compiler, params=params, ) @@ -724,7 +712,6 @@ def run(rank, world_size, args): model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - graph_compiler=graph_compiler, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, @@ -760,6 +747,8 @@ def main(): WenetSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.return_cuts = True world_size = args.world_size assert world_size >= 1 diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index 40960c2ae..5d9d8de36 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -53,6 +53,7 @@ import argparse import copy import logging import warnings +from functools import partial from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union @@ -79,7 +80,6 @@ from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -90,11 +90,11 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, + num_tokens, setup_logger, str2bool, text_to_pinyin, @@ -776,7 +776,6 @@ def save_checkpoint( def 
compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: @@ -811,7 +810,7 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts, sep="/") + y = [c.supervisions[0].tokens for c in supervisions["cut"]] y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): @@ -859,7 +858,6 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, ) -> MetricsTracker: @@ -872,7 +870,6 @@ def compute_validation_loss( loss, loss_info = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=False, ) @@ -895,7 +892,6 @@ def train_one_epoch( model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, @@ -971,7 +967,6 @@ def train_one_epoch( loss, loss_info = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=True, ) @@ -988,7 +983,7 @@ def train_one_epoch( optimizer.zero_grad() except: # noqa save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params) raise if params.print_diagnostics and batch_idx == 5: @@ -1077,7 +1072,6 @@ def train_one_epoch( valid_info = compute_validation_loss( params=params, model=model, - graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, ) @@ -1098,6 +1092,20 @@ def train_one_epoch( params.best_train_loss = params.train_loss +def encode_text(c: Cut, token_table: k2.SymbolTable, params: AttributeDict): + text = c.supervisions[0].text + tokens = text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) + ids = [] + for t in tokens: + if t in token_table: + ids.append(token_table[t]) + else: + logging.warning(f"Text : {text} has OOV token : {t} , encode to ") + ids.append(token_table[""]) + c.supervisions[0].tokens = ids + return c + + def run(rank, world_size, args): """ Args: @@ -1130,14 +1138,10 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) + token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt") - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 if not params.use_transducer: params.ctc_loss_scale = 1.0 @@ -1216,17 +1220,10 @@ def run(rank, world_size, args): return True - def encode_text(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = "/".join( - text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) - ) - c.supervisions[0].text = text - return c + _encode_text = partial(encode_text, token_table=token_table, params=params) train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.map(encode_text) + train_cuts = train_cuts.map(_encode_text) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the 
sampler's state dict when it loads a checkpoint @@ -1240,7 +1237,7 @@ def run(rank, world_size, args): ) valid_cuts = wenetspeech.valid_cuts() - valid_cuts = valid_cuts.map(encode_text) + valid_cuts = valid_cuts.map(_encode_text) valid_dl = wenetspeech.valid_dataloaders(valid_cuts) if not params.print_diagnostics and params.scan_for_oom_batches: @@ -1248,7 +1245,6 @@ def run(rank, world_size, args): model=model, train_dl=train_dl, optimizer=optimizer, - graph_compiler=graph_compiler, params=params, ) @@ -1273,7 +1269,6 @@ def run(rank, world_size, args): model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - graph_compiler=graph_compiler, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, @@ -1307,7 +1302,6 @@ def run(rank, world_size, args): def display_and_save_batch( batch: dict, params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, ) -> None: """Display the batch statistics and save the batch into disk. @@ -1317,8 +1311,6 @@ def display_and_save_batch( for the content in it. params: Parameters for training. See :func:`get_params`. - graph_compiler: - The compiler to encode texts to ids. """ from lhotse.utils import uuid4 @@ -1332,8 +1324,8 @@ def display_and_save_batch( logging.info(f"features shape: {features.shape}") texts = supervisions["text"] - y = graph_compiler.texts_to_ids(texts) - num_tokens = sum(len(i) for i in y) + tokens = [c.supervisions[0].tokens for c in supervisions["cut"]] + num_tokens = sum(len(i) for i in tokens) logging.info(f"num tokens: {num_tokens}") @@ -1341,7 +1333,6 @@ def scan_pessimistic_batches_for_oom( model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, params: AttributeDict, ): from lhotse.dataset import find_pessimistic_batches @@ -1357,7 +1348,6 @@ def scan_pessimistic_batches_for_oom( loss, _ = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=True, ) @@ -1372,7 +1362,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params) raise logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" @@ -1385,6 +1375,7 @@ def main(): args = parser.parse_args() args.lang_dir = Path(args.lang_dir) args.exp_dir = Path(args.exp_dir) + args.return_cuts = True world_size = args.world_size assert world_size >= 1 diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py index e0a94bf08..3de7136ec 100755 --- a/egs/yesno/ASR/local/compile_hlg.py +++ b/egs/yesno/ASR/local/compile_hlg.py @@ -47,7 +47,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. 
max_token_id: {max_token_id}") H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) logging.info("Loading G.fst.txt") with open("data/lm/G.fst.txt") as f: diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index f520607af..479e195fa 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -271,7 +271,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) assert HLG.requires_grad is False diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py index 6c643c263..2a0879045 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -131,7 +131,9 @@ def main(): model.to(device) logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) logging.info("Constructing Fbank computer") diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py index 968a9e9a8..e6471d2db 100755 --- a/egs/yesno/ASR/tdnn/onnx_pretrained.py +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -176,7 +176,9 @@ def main(): model = OnnxModel(params.nn_model) logging.info(f"Loading HLG from {args.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) logging.info("Constructing Fbank computer") diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index bea520998..d4f3ae39f 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -148,13 +148,15 @@ def main(): num_classes=params.num_classes, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"]) model.to(device) model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) logging.info("Constructing Fbank computer") diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index d31ce1301..4ab685684 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -27,7 +27,6 @@ import torch import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -43,7 +42,7 @@ def save_checkpoint( params: Optional[Dict[str, Any]] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: @@ -102,7 +101,7 @@ def load_checkpoint( model_avg: Optional[nn.Module] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, + scaler: 
Optional["GradScaler"] = None, sampler: Optional[CutSampler] = None, strict: bool = False, ) -> Dict[str, Any]: @@ -110,7 +109,7 @@ def load_checkpoint( TODO: document it """ logging.info(f"Loading checkpoint from {filename}") - checkpoint = torch.load(filename, map_location="cpu") + checkpoint = torch.load(filename, map_location="cpu", weights_only=False) if next(iter(checkpoint["model"])).startswith("module."): logging.info("Loading checkpoint saved by DDP") @@ -163,7 +162,7 @@ def average_checkpoints( """ n = len(filenames) - avg = torch.load(filenames[0], map_location=device)["model"] + avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"] # Identify shared parameters. Two parameters are said to be shared # if they have the same data_ptr @@ -178,7 +177,9 @@ def average_checkpoints( uniqued_names = list(uniqued.values()) for i in range(1, n): - state_dict = torch.load(filenames[i], map_location=device)["model"] + state_dict = torch.load(filenames[i], map_location=device, weights_only=False)[ + "model" + ] for k in uniqued_names: avg[k] += state_dict[k] @@ -199,7 +200,7 @@ def save_checkpoint_with_global_batch_idx( params: Optional[Dict[str, Any]] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ): @@ -421,8 +422,10 @@ def average_checkpoints_with_averaged_model( device: Move checkpoints to this device before averaging. """ - state_dict_start = torch.load(filename_start, map_location=device) - state_dict_end = torch.load(filename_end, map_location=device) + state_dict_start = torch.load( + filename_start, map_location=device, weights_only=False + ) + state_dict_end = torch.load(filename_end, map_location=device, weights_only=False) average_period = state_dict_start["average_period"] diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index e5eaba619..d923e8842 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -631,7 +631,10 @@ def attach_diagnostics( ) module.register_forward_hook(forward_hook) - module.register_backward_hook(backward_hook) + if hasattr(module, "register_full_backward_hook"): + module.register_full_backward_hook(backward_hook) + else: + module.register_backward_hook(backward_hook) if type(module).__name__ in [ "Sigmoid", @@ -665,7 +668,10 @@ def attach_diagnostics( _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output) module.register_forward_hook(scalar_forward_hook) - module.register_backward_hook(scalar_backward_hook) + if hasattr(module, "register_full_backward_hook"): + module.register_full_backward_hook(scalar_backward_hook) + else: + module.register_backward_hook(scalar_backward_hook) for name, parameter in model.named_parameters(): diff --git a/icefall/hooks.py b/icefall/hooks.py index 85583acbe..b543190be 100644 --- a/icefall/hooks.py +++ b/icefall/hooks.py @@ -77,7 +77,11 @@ def register_inf_check_hooks(model: nn.Module) -> None: logging.warning(f"The sum of {_name}.grad[{i}] is not finite") module.register_forward_hook(forward_hook) - module.register_backward_hook(backward_hook) + + if hasattr(module, "register_full_backward_hook"): + module.register_full_backward_hook(backward_hook) + else: + module.register_backward_hook(backward_hook) for name, parameter in model.named_parameters(): diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 0178b80bf..023afb5a5 100755 --- a/icefall/rnn_lm/train.py +++ 
b/icefall/rnn_lm/train.py @@ -53,7 +53,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) def get_parser(): @@ -401,7 +407,7 @@ def compute_validation_loss( for batch_idx, batch in enumerate(valid_dl): x, y, sentence_lengths = batch - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, @@ -470,7 +476,7 @@ def train_one_epoch( params.batch_idx_train += 1 x, y, sentence_lengths = batch batch_size = x.size(0) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py index c36abfcdf..acec95e94 100644 --- a/icefall/transformer_lm/train.py +++ b/icefall/transformer_lm/train.py @@ -50,7 +50,13 @@ from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) def get_parser(): @@ -341,7 +347,7 @@ def compute_validation_loss( for batch_idx, batch in enumerate(valid_dl): x, y, sentence_lengths = batch - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, @@ -403,7 +409,7 @@ def train_one_epoch( params.batch_idx_train += 1 x, y, sentence_lengths = batch batch_size = x.size(0) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, diff --git a/icefall/utils.py b/icefall/utils.py index aab479e56..022f83b3b 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -26,6 +26,7 @@ import pathlib import random import re import subprocess +import warnings from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass @@ -42,6 +43,7 @@ import torch import torch.distributed as dist import torch.nn as nn from lhotse.dataset.signal_transforms import time_warp as time_warp_impl +from packaging import version from pypinyin import lazy_pinyin, pinyin from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from torch.utils.tensorboard import SummaryWriter @@ -50,6 +52,48 @@ from icefall.checkpoint import average_checkpoints Pathlike = Union[str, Path] +TORCH_VERSION = version.parse(torch.__version__) + + +def create_grad_scaler(device="cuda", **kwargs): + """ + Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0. + Accepts all kwargs like: enabled, init_scale, growth_factor, etc. + + /icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning: + `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use + `torch.amp.GradScaler('cuda', args...)` instead. 
+ """ + if TORCH_VERSION >= version.parse("2.3.0"): + from torch.amp import GradScaler + + return GradScaler(device=device, **kwargs) + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + return torch.cuda.amp.GradScaler(**kwargs) + + +@contextmanager +def torch_autocast(device_type="cuda", **kwargs): + """ + To fix the following warnings: + /icefall/egs/librispeech/ASR/zipformer/model.py:323: + FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. + Please use `torch.amp.autocast('cuda', args...)` instead. + with torch.cuda.amp.autocast(enabled=False): + """ + if TORCH_VERSION >= version.parse("2.3.0"): + # Use new unified API + with torch.amp.autocast(device_type=device_type, **kwargs): + yield + else: + # Suppress deprecation warning and use old CUDA-specific autocast + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + with torch.cuda.amp.autocast(**kwargs): + yield + # Pytorch issue: https://github.com/pytorch/pytorch/issues/47379 # Fixed: https://github.com/pytorch/pytorch/pull/49853 @@ -186,7 +230,7 @@ class AttributeDict(dict): tmp = {} for k, v in self.items(): # PosixPath is ont JSON serializable - if isinstance(v, pathlib.Path) or isinstance(v, torch.device): + if isinstance(v, (pathlib.Path, torch.device, torch.dtype)): v = str(v) tmp[k] = v return json.dumps(tmp, indent=indent, sort_keys=True) @@ -1551,6 +1595,7 @@ def optim_step_and_measure_param_change( and the L2 norm of the original parameter. It is given by the formula: .. math:: + \begin{aligned} \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2} \end{aligned} diff --git a/requirements.txt b/requirements.txt index d97263142..885bf2fc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ flake8==5.0.4 # cantonese word segment support pycantonese==3.4.0 +packaging