diff --git a/.github/scripts/baker_zh/TTS/run-matcha.sh b/.github/scripts/baker_zh/TTS/run-matcha.sh new file mode 100755 index 000000000..150f023ae --- /dev/null +++ b/.github/scripts/baker_zh/TTS/run-matcha.sh @@ -0,0 +1,167 @@ +#!/usr/bin/env bash + +set -ex + +apt-get update +apt-get install -y sox + +python3 -m pip install numba conformer==0.3.2 diffusers librosa +python3 -m pip install jieba + + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/baker_zh/TTS + +sed -i.bak s/600/8/g ./prepare.sh +sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh +sed -i.bak s/500/5/g ./prepare.sh +git diff + +function prepare_data() { + # We have created a subset of the data for testing + # + mkdir -p download + pushd download + wget -q https://huggingface.co/csukuangfj/tmp-files/resolve/main/BZNSYP-samples.tar.bz2 + tar xvf BZNSYP-samples.tar.bz2 + mv BZNSYP-samples BZNSYP + rm BZNSYP-samples.tar.bz2 + popd + + ./prepare.sh + tree . +} + +function train() { + pushd ./matcha + sed -i.bak s/1500/3/g ./train.py + git diff . + popd + + ./matcha/train.py \ + --exp-dir matcha/exp \ + --num-epochs 1 \ + --save-every-n 1 \ + --num-buckets 2 \ + --tokens data/tokens.txt \ + --max-duration 20 + + ls -lh matcha/exp +} + +function infer() { + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 + + ./matcha/infer.py \ + --num-buckets 2 \ + --epoch 1 \ + --exp-dir ./matcha/exp \ + --tokens data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json \ + --vocoder ./generator_v2 \ + --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ + --output-wav ./generated.wav + + ls -lh *.wav + soxi ./generated.wav + rm -v ./generated.wav + rm -v generator_v2 +} + +function export_onnx() { + pushd matcha/exp + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27/resolve/main/epoch-2000.pt + popd + + pushd data/fbank + rm -v *.json + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27/resolve/main/cmvn.json + popd + + ./matcha/export_onnx.py \ + --exp-dir ./matcha/exp \ + --epoch 2000 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json + + ls -lh *.onnx + + if false; then + # The CI machine does not have enough memory to run it + # + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 + python3 ./matcha/export_onnx_hifigan.py + else + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v3.onnx + fi + + ls -lh *.onnx + + python3 ./matcha/generate_lexicon.py + + for v in v1 v2 v3; do + python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-6.onnx \ + --vocoder ./hifigan_$v.onnx \ + --tokens ./data/tokens.txt \ + --lexicon ./lexicon.txt \ + --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ + --output-wav /icefall/generated-matcha-tts-steps-6-$v.wav + done + + ls 
-lh /icefall/*.wav + soxi /icefall/generated-matcha-tts-steps-6-*.wav + cp ./model-steps-*.onnx /icefall + + d=matcha-icefall-zh-baker + mkdir $d + cp -v data/tokens.txt $d + cp -v lexicon.txt $d + cp model-steps-3.onnx $d + pushd $d + curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2 + tar xvf dict.tar.bz2 + rm dict.tar.bz2 + + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst + +cat >README.md <= 2.3.0 torch_version += ["2.3.0", "2.3.1"] torch_version += ["2.4.0"] torch_version += ["2.4.1"] torch_version += ["2.5.0"] + torch_version += ["2.5.1"] + torch_version += ["2.6.0", "2.7.0", "2.7.1"] + + if specified_torch_version: + torch_version = [specified_torch_version] + + if specified_python_version: + python_version = [specified_python_version] matrix = [] for p in python_version: for t in torch_version: + if min_torch_version and version_gt(min_torch_version, t): + continue + # torchaudio <= 1.13.x supports only python <= 3.10 if version_gt(p, "3.10") and not version_gt(t, "2.0"): @@ -96,7 +127,12 @@ def get_matrix(): def main(): - matrix = get_matrix() + args = get_args() + matrix = get_matrix( + min_torch_version=args.min_torch_version, + specified_torch_version=args.torch_version, + specified_python_version=args.python_version, + ) print(json.dumps({"include": matrix})) diff --git a/.github/scripts/generate-piper-phonemize-page.py b/.github/scripts/generate-piper-phonemize-page.py index 3784d5fa5..e268acf03 100755 --- a/.github/scripts/generate-piper-phonemize-page.py +++ b/.github/scripts/generate-piper-phonemize-page.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 -def main(): +def get_v1_2_0_files(): prefix = ( "https://github.com/csukuangfj/piper-phonemize/releases/download/2023.12.5/" ) @@ -19,9 +19,70 @@ def main(): "piper_phonemize-1.2.0-cp39-cp39-macosx_10_14_x86_64.whl", "piper_phonemize-1.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", ] + ans = [prefix + f for f in files] + ans.sort() + return ans + + +def get_v1_3_0_files(): + prefix = ( + "https://github.com/csukuangfj/piper-phonemize/releases/download/2025.06.23/" + ) + files = [ + "piper_phonemize-1.3.0-cp310-cp310-macosx_10_9_universal2.whl", + "piper_phonemize-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", + "piper_phonemize-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp310-cp310-win_amd64.whl", + "piper_phonemize-1.3.0-cp311-cp311-macosx_10_9_universal2.whl", + "piper_phonemize-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", + "piper_phonemize-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp311-cp311-win_amd64.whl", + "piper_phonemize-1.3.0-cp312-cp312-macosx_10_13_universal2.whl", + 
"piper_phonemize-1.3.0-cp312-cp312-macosx_10_13_x86_64.whl", + "piper_phonemize-1.3.0-cp312-cp312-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp312-cp312-win_amd64.whl", + "piper_phonemize-1.3.0-cp313-cp313-macosx_10_13_universal2.whl", + "piper_phonemize-1.3.0-cp313-cp313-macosx_10_13_x86_64.whl", + "piper_phonemize-1.3.0-cp313-cp313-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp313-cp313-win_amd64.whl", + "piper_phonemize-1.3.0-cp38-cp38-macosx_10_9_universal2.whl", + "piper_phonemize-1.3.0-cp38-cp38-macosx_10_9_x86_64.whl", + "piper_phonemize-1.3.0-cp38-cp38-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp38-cp38-win_amd64.whl", + "piper_phonemize-1.3.0-cp39-cp39-macosx_10_9_universal2.whl", + "piper_phonemize-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", + "piper_phonemize-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", + "piper_phonemize-1.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", + "piper_phonemize-1.3.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", + "piper_phonemize-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", + "piper_phonemize-1.3.0-cp39-cp39-win_amd64.whl", + ] + ans = [prefix + f for f in files] + ans.sort() + return ans + + +def main(): + files = get_v1_3_0_files() + get_v1_2_0_files() + with open("piper_phonemize.html", "w") as f: - for file in files: - url = prefix + file + for url in files: + file = url.split("/")[-1] f.write(f'{file}
\n') diff --git a/.github/scripts/librispeech/ASR/run_rknn.sh b/.github/scripts/librispeech/ASR/run_rknn.sh new file mode 100755 index 000000000..bc7b00f0c --- /dev/null +++ b/.github/scripts/librispeech/ASR/run_rknn.sh @@ -0,0 +1,275 @@ +#!/usr/bin/env bash + +set -ex + +python3 -m pip install kaldi-native-fbank soundfile librosa + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +# https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed +# sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 +function export_2023_02_20() { + d=exp_2023_02_20 + + mkdir $d + pushd $d + + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/exp/pretrained.pt + mv pretrained.pt epoch-99.pt + + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/data/lang_char_bpe/tokens.txt + + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/0.wav + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/1.wav + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/2.wav + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/3.wav + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-chinese-english-mixed/resolve/main/test_wavs/4.wav + ls -lh + popd + + ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ + --dynamic-batch 0 \ + --enable-int8-quantization 0 \ + --tokens $d/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $d/ \ + --decode-chunk-len 64 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,1536,1536,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + + ls -lh $d/ + + ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \ + --tokens $d/tokens.txt \ + $d/0.wav + + ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \ + --tokens $d/tokens.txt \ + $d/1.wav + + for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do + dst=sherpa-onnx-$platform-streaming-zipformer-bilingual-zh-en-2023-02-20 + mkdir -p $dst + + ./pruned_transducer_stateless7_streaming/export_rknn.py \ + --in-encoder $d/encoder-epoch-99-avg-1.onnx \ + --in-decoder $d/decoder-epoch-99-avg-1.onnx \ + --in-joiner $d/joiner-epoch-99-avg-1.onnx \ + --out-encoder $dst/encoder.rknn \ + --out-decoder $dst/decoder.rknn \ + --out-joiner $dst/joiner.rknn \ + --target-platform $platform 2>/dev/null + + ls -lh $dst/ + + ./pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py \ + --encoder $d/encoder-epoch-99-avg-1.onnx \ + --decoder $d/decoder-epoch-99-avg-1.onnx \ + --joiner $d/joiner-epoch-99-avg-1.onnx \ + 
--tokens $d/tokens.txt \ + --wav $d/0.wav + + cp $d/tokens.txt $dst + mkdir $dst/test_wavs + cp $d/*.wav $dst/test_wavs + + tar cjvf $dst.tar.bz2 $dst + ls -lh $dst.tar.bz2 + mv $dst.tar.bz2 /icefall/ + ls -lh $dst/ + echo "---" + + rm -rf $dst + done +} + +# https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t +# sherpa-onnx-streaming-zipformer-small-bilingual-zh-en-2023-02-16 +function export_2023_02_16() { + d=exp_2023_02_16 + + mkdir $d + pushd $d + + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/exp/pretrained.pt + mv pretrained.pt epoch-99.pt + + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/data/lang_char_bpe/tokens.txt + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/0.wav + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/1.wav + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/2.wav + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/3.wav + curl -SL -O https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t/resolve/main/test_wavs/4.wav + + ls -lh + + popd + + ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ + --dynamic-batch 0 \ + --enable-int8-quantization 0 \ + --tokens $d/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $d/ \ + --decode-chunk-len 64 \ + \ + --num-encoder-layers 2,2,2,2,2 \ + --feedforward-dims 768,768,768,768,768 \ + --nhead 4,4,4,4,4 \ + --encoder-dims 256,256,256,256,256 \ + --attention-dims 192,192,192,192,192 \ + --encoder-unmasked-dims 192,192,192,192,192 \ + \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + + ls -lh $d/ + + ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \ + --tokens $d/tokens.txt \ + $d/0.wav + + ./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ + --encoder-model-filename $d/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $d/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $d/joiner-epoch-99-avg-1.onnx \ + --tokens $d/tokens.txt \ + $d/1.wav + + for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do + dst=sherpa-onnx-$platform-streaming-zipformer-small-bilingual-zh-en-2023-02-16 + mkdir -p $dst + + ./pruned_transducer_stateless7_streaming/export_rknn.py \ + --in-encoder $d/encoder-epoch-99-avg-1.onnx \ + --in-decoder $d/decoder-epoch-99-avg-1.onnx \ + --in-joiner $d/joiner-epoch-99-avg-1.onnx \ + --out-encoder $dst/encoder.rknn \ + --out-decoder $dst/decoder.rknn \ + --out-joiner $dst/joiner.rknn \ + --target-platform $platform 2>/dev/null + + ls -lh $dst/ + + ./pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py \ + --encoder $d/encoder-epoch-99-avg-1.onnx \ + --decoder $d/decoder-epoch-99-avg-1.onnx \ + --joiner $d/joiner-epoch-99-avg-1.onnx \ + --tokens $d/tokens.txt \ + --wav $d/0.wav + + cp $d/tokens.txt $dst + mkdir $dst/test_wavs + cp $d/*.wav $dst/test_wavs + + tar cjvf $dst.tar.bz2 $dst + ls -lh $dst.tar.bz2 + mv $dst.tar.bz2 /icefall/ + ls -lh $dst/ + echo "---" + + rm -rf $dst + done +} + +# 
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-en-2023-06-26-english +function export_2023_06_26() { + d=exp_2023_06_26 + + mkdir $d + pushd $d + + curl -SL -O https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17/resolve/main/exp/pretrained.pt + mv pretrained.pt epoch-99.pt + + curl -SL -O https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17/resolve/main/data/lang_bpe_500/tokens.txt + + curl -SL -o 0.wav https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17/resolve/main/test_wavs/1089-134686-0001.wav + curl -SL -o 1.wav https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17/resolve/main/test_wavs/1221-135766-0001.wav + curl -SL -o 2.wav https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17/resolve/main/test_wavs/1221-135766-0002.wav + + ls -lh + + popd + + ./zipformer/export-onnx-streaming.py \ + --dynamic-batch 0 \ + --enable-int8-quantization 0 \ + --tokens $d/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $d \ + --use-ctc 0 \ + --use-transducer 1 \ + \ + --chunk-size 32 \ + --left-context-frames 128 \ + --causal 1 + + ls -lh $d/ + + for platform in rk3562 rk3566 rk3568 rk3576 rk3588; do + dst=sherpa-onnx-$platform-streaming-zipformer-en-2023-06-26 + mkdir -p $dst + + ./zipformer/export_rknn_transducer_streaming.py \ + --in-encoder $d/encoder-epoch-99-avg-1-chunk-32-left-128.onnx \ + --in-decoder $d/decoder-epoch-99-avg-1-chunk-32-left-128.onnx \ + --in-joiner $d/joiner-epoch-99-avg-1-chunk-32-left-128.onnx \ + --out-encoder $dst/encoder.rknn \ + --out-decoder $dst/decoder.rknn \ + --out-joiner $dst/joiner.rknn \ + --target-platform $platform + + ls -lh $dst/ + + cp $d/tokens.txt $dst + mkdir $dst/test_wavs + cp $d/*.wav $dst/test_wavs + + tar cjvf $dst.tar.bz2 $dst + ls -lh $dst.tar.bz2 + mv $dst.tar.bz2 /icefall/ + ls -lh $dst/ + echo "---" + + rm -rf $dst + done +} + +if [[ $rknn_toolkit2_version == "2.1.0" ]]; then + export_2023_02_16 + export_2023_02_20 +else + export_2023_06_26 +fi diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh index 37e1bc320..bfb37fb6d 100755 --- a/.github/scripts/ljspeech/TTS/run-matcha.sh +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -56,7 +56,8 @@ function infer() { curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 - ./matcha/inference.py \ + ./matcha/infer.py \ + --num-buckets 2 \ --epoch 1 \ --exp-dir ./matcha/exp \ --tokens data/tokens.txt \ @@ -76,7 +77,7 @@ function export_onnx() { popd pushd data/fbank - rm -v *.json + rm -fv *.json curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/data/cmvn.json popd @@ -89,7 +90,7 @@ function export_onnx() { ls -lh *.onnx if false; then - # THe CI machine does not have enough memory to run it + # The CI machine does not have enough memory to run it # curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 python3 ./matcha/export_onnx_hifigan.py else curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx + curl -SL -O 
https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v3.onnx fi ls -lh *.onnx - python3 ./matcha/onnx_pretrained.py \ - --acoustic-model ./model-steps-6.onnx \ - --vocoder ./hifigan_v1.onnx \ - --tokens ./data/tokens.txt \ - --input-text "how are you doing?" \ - --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav + for v in v1 v2 v3; do + python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-6.onnx \ + --vocoder ./hifigan_$v.onnx \ + --tokens ./data/tokens.txt \ + --input-text "how are you doing?" \ + --output-wav /icefall/generated-matcha-tts-steps-6-$v.wav + done ls -lh /icefall/*.wav - soxi /icefall/generated-matcha-tts-steps-6-v1.wav + soxi /icefall/generated-matcha-tts-steps-6-*.wav + + cp ./model-steps-*.onnx /icefall + + d=matcha-icefall-en_US-ljspeech + mkdir $d + cp -v data/tokens.txt $d + cp model-steps-3.onnx $d + pushd $d + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2 + tar xf espeak-ng-data.tar.bz2 + rm espeak-ng-data.tar.bz2 + +cat >README.md <README.md <README.md < log-$recipe.txt 2>&1 || true + + - uses: actions/upload-artifact@v4 + with: + name: log-${{ matrix.recipe }}-${{ matrix.rknn_toolkit2_version }} + path: ./log-*.txt + + - name: Display results + shell: bash + run: | + ls -lh *rk*.tar.bz2 || true + + - name: Release to GitHub + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + overwrite: true + file: sherpa-onnx-*.tar.bz2 + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: asr-models + + - name: Upload model to huggingface + if: github.event_name == 'push' || github.event_name == 'workflow_dispatch' + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + uses: nick-fields/retry@v3 + with: + max_attempts: 20 + timeout_seconds: 200 + shell: bash + command: | + git config --global user.email "csukuangfj@gmail.com" + git config --global user.name "Fangjun Kuang" + + rm -rf huggingface + export GIT_LFS_SKIP_SMUDGE=1 + + git clone https://huggingface.co/csukuangfj/sherpa-onnx-rknn-models huggingface + cd huggingface + + git fetch + git pull + git merge -m "merge remote" --ff origin main + dst=streaming-asr + mkdir -p $dst + cp ../*rk*.tar.bz2 $dst/ || true + + ls -lh $dst + git add . + git status + git commit -m "update models" + git status + + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-rknn-models main || true + rm -rf huggingface diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 0681ece60..908c3cc43 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -36,7 +36,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.8] + python-version: ["3.10"] fail-fast: false steps: @@ -69,7 +69,7 @@ jobs: working-directory: ${{github.workspace}} run: | black --check --diff . 
- + - name: Run isort shell: bash working-directory: ${{github.workspace}} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c22f2edb5..ed0e62330 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,8 +30,8 @@ jobs: id: set-matrix run: | # outputting for debugging purposes - python ./.github/scripts/docker/generate_build_matrix.py - MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10") echo "::set-output name=matrix::${MATRIX}" test: needs: generate_build_matrix diff --git a/.github/workflows/yesno.yml b/.github/workflows/yesno.yml index a9d65516f..a5832df9d 100644 --- a/.github/workflows/yesno.yml +++ b/.github/workflows/yesno.yml @@ -30,8 +30,9 @@ jobs: id: set-matrix run: | # outputting for debugging purposes - python ./.github/scripts/docker/generate_build_matrix.py - MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10") + # MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --python-version "3.10" --min-torch-version "2.5.0") echo "::set-output name=matrix::${MATRIX}" yesno: needs: generate_build_matrix diff --git a/README.md b/README.md index 0e550ffb1..498f7e3b4 100644 --- a/README.md +++ b/README.md @@ -383,3 +383,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [vctk]: egs/vctk/TTS [ljspeech]: egs/ljspeech/TTS [libritts_tts]: egs/libritts/TTS + +## Acknowledgements + +Some contributors to this project were supported by Xiaomi Corporation. Others were supported by National Science Foundation CCRI award 2120435. This is not an exhaustive list of sources of support. diff --git a/docs/source/for-dummies/model-export.rst b/docs/source/for-dummies/model-export.rst index 352a0dc90..a3dd9088f 100644 --- a/docs/source/for-dummies/model-export.rst +++ b/docs/source/for-dummies/model-export.rst @@ -41,7 +41,7 @@ To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the .. code-block:: python3 >>> import torch - >>> m = torch.load("tdnn/exp/pretrained.pt") + >>> m = torch.load("tdnn/exp/pretrained.pt", weights_only=False) >>> list(m.keys()) ['model'] >>> list(m["model"].keys()) diff --git a/egs/aidatatang_200zh/ASR/local/prepare_lang.py b/egs/aidatatang_200zh/ASR/local/prepare_lang.py index c8cf9b881..aa23c4cb3 100755 --- a/egs/aidatatang_200zh/ASR/local/prepare_lang.py +++ b/egs/aidatatang_200zh/ASR/local/prepare_lang.py @@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. 
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py index 17729e02e..d0dc36eff 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py @@ -224,7 +224,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index fa809b768..9060cdb26 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -638,7 +644,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 355d1516d..38a94d6c6 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -1,5 +1,63 @@ ## Results +### Aishell training results (zipformer + CR-CTC) + +See for more details. 
+ +[zipformer](./zipformer) + +#### Non-streaming + +##### medium-scale model, number of model parameters: 66218471, i.e., 66.2 M + +| decoding method | test | dev | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-search | 3.98 | 3.69 | --epoch 60 --avg 28 | +| ctc-prefix-beam-search | 3.98 | 3.70 | --epoch 60 --avg 21 | + +The training command using 2 32G-V100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 60 \ + --start-epoch 1 \ + --use-fp16 1 \ + --context-size 1 \ + --enable-musan 0 \ + --exp-dir zipformer/exp \ + --max-duration 500 \ + --base-lr 0.045 \ + --lr-batches 7500 \ + --lr-epochs 18 \ + --spec-aug-time-warp-factor 20 \ + --use-ctc 1 \ + --use-cr-ctc 1 \ + --use-transducer 0 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search ctc-prefix-beam-search; do + ./zipformer/ctc_decode.py \ + --epoch 60 \ + --avg 28 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --max-duration 600 \ + --decoding-method $m +done +``` + +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at + + ### Aishell training results (Fine-tuning Pretrained Models) #### Whisper [./whisper](./whisper) diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index 2cb476e20..90881ee40 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -503,7 +503,7 @@ def main(): else: H = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) ) assert HLG.requires_grad is False diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index af1171a6f..4caff4e16 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -249,7 +249,7 @@ def main(): use_feat_batchnorm=params.use_feat_batchnorm, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -315,7 +315,7 @@ def main(): hyps = [[token_sym_table[i] for i in ids] for ids in token_ids] elif params.method in ["1best", "attention-decoder"]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False)) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index 8a2daa93e..c88aea41a 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -516,7 +516,7 @@ def main(): else: H = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) ) assert HLG.requires_grad is False diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py index c8cf9b881..aa23c4cb3 100755 --- a/egs/aishell/ASR/local/prepare_lang.py +++ 
b/egs/aishell/ASR/local/prepare_lang.py @@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py index c4aa98358..2bcf34de8 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py @@ -227,7 +227,7 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index 60f014c48..457b564fe 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -72,7 +72,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -688,7 +694,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -989,7 +995,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py index a4dda0d6d..3b9dad55e 100644 --- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -184,7 +184,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -219,7 +219,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py index 69fe3a40b..bf46a099b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py @@ -228,7 +228,7 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index 7c23041ca..ad9f40e25 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -94,7 +94,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -797,7 +803,7 @@ def train_one_epoch( aishell = is_aishell(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index 2dc835f3b..85a51278b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -94,6 +94,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -809,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py index 46f542641..40e0565bb 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py @@ -773,7 +773,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py index 12004315b..1972d05c8 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py @@ -237,7 +237,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py index 811269989..a07216de8 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -87,6 +87,7 @@ from icefall.utils import ( setup_logger, str2bool, tokenize_by_CJK_char, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -802,7 +803,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1202,7 +1203,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git 
a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py index f3b0f1e11..a8373d755 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py @@ -81,7 +81,13 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -812,7 +818,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1202,7 +1208,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index aacbd153d..af8f5498a 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -92,7 +92,7 @@ class AishellAsrDataModule: group.add_argument( "--num-buckets", type=int, - default=30, + default=15, help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) @@ -275,8 +275,7 @@ class AishellAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, + buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index 05e52f560..a6dfd8a75 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -337,7 +337,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)) HLG = HLG.to(device) assert HLG.requires_grad is False diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py index 9754b4939..6cfe2de89 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py @@ -139,13 +139,13 @@ def main(): subsampling_factor=params.subsampling_factor, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"]) model.to(device) model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False)) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder diff --git 
a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index 540e7b61b..b52139d88 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -245,7 +245,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index 4a4e9237c..56353712a 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -225,7 +225,7 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py index 66a91709e..28e8fbf28 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -225,7 +225,7 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"]) model.to(device) model.eval() diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py index 5350cb2b0..75d3c5a65 100755 --- a/egs/aishell/ASR/whisper/decode.py +++ b/egs/aishell/ASR/whisper/decode.py @@ -89,10 +89,10 @@ def average_checkpoints( """ n = len(filenames) - if "model" in torch.load(filenames[0], map_location=device): - avg = torch.load(filenames[0], map_location=device)["model"] + if "model" in torch.load(filenames[0], map_location=device, weights_only=False): + avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"] else: - avg = torch.load(filenames[0], map_location=device) + avg = torch.load(filenames[0], map_location=device, weights_only=False) # Identify shared parameters. 
Two parameters are said to be shared # if they have the same data_ptr @@ -107,10 +107,10 @@ def average_checkpoints( uniqued_names = list(uniqued.values()) for i in range(1, n): - if "model" in torch.load(filenames[i], map_location=device): - state_dict = torch.load(filenames[i], map_location=device)["model"] + if "model" in torch.load(filenames[i], map_location=device, weights_only=False): + state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] else: - state_dict = torch.load(filenames[i], map_location=device) + state_dict = torch.load(filenames[i], map_location=device, weights_only=False) for k in uniqued_names: avg[k] += state_dict[k] @@ -440,7 +440,7 @@ def main(): start = params.epoch - params.avg assert start >= 1, start checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False ) if "model" not in checkpoint: # deepspeed converted checkpoint only contains model state_dict @@ -469,7 +469,7 @@ def main(): torch.save(model.state_dict(), filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False ) if "model" not in checkpoint: model.load_state_dict(checkpoint, strict=True) diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index d77f8c270..af4d6442e 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -81,6 +81,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -514,7 +515,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -608,7 +609,7 @@ def train_one_epoch( ) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, diff --git a/egs/aishell/ASR/zipformer/ctc_decode.py b/egs/aishell/ASR/zipformer/ctc_decode.py new file mode 100755 index 000000000..940aef1e5 --- /dev/null +++ b/egs/aishell/ASR/zipformer/ctc_decode.py @@ -0,0 +1,540 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao, +# Zhifeng Han,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +(1) ctc-greedy-search (with cr-ctc) +./zipformer/ctc_decode.py \ + --epoch 60 \ + --avg 28 \ + --exp-dir ./zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search +(2) ctc-prefix-beam-search (with cr-ctc) +./zipformer/ctc_decode.py \ + --epoch 60 \ + --avg 21 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --max-duration 600 \ + --decoding-method ctc-prefix-beam-search +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + ctc_greedy_search, + ctc_prefix_beam_search, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-greedy-search", + help="""Decoding method. + Supported values are: + - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + (2) ctc-prefix-beam-search. Extract n paths with the given beam, the best + path of the n paths is the decoding result. 
+ """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "beam": 4, # for prefix-beam-search + } + ) + return params + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + batch: dict, +) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + hyp_tokens = [] + hyps = [] + + if params.decoding_method == "ctc-greedy-search": + hyp_tokens = ctc_greedy_search( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + ) + elif params.decoding_method == "ctc-prefix-beam-search": + hyp_tokens = ctc_prefix_beam_search( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + + if params.decoding_method == "ctc-greedy-search": + return {"ctc-greedy-search" : hyps} + elif params.decoding_method == "ctc-prefix-beam-search": + return {"ctc-prefix-beam-search" : hyps} + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. 
Each tuple contains 3 elements: + Respectively, they are cut_id, the reference transcript, and the predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + batch=batch, + ) + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results, char_level = True) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-greedy-search", + "ctc-prefix-beam-search", + ) # support ctc-greedy-search and ctc-prefix-beam-search + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
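+        # Append the streaming configuration to the suffix so that result
+        # files from different chunk / left-context settings do not
+        # overwrite each other.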
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + params.device = device + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + aishell = AishellAsrDataModule(args) + + dev_cuts = aishell.valid_cuts() + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_cuts() + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/decode.py b/egs/aishell/ASR/zipformer/decode.py index 538189e52..85b75c988 100755 --- a/egs/aishell/ASR/zipformer/decode.py +++ b/egs/aishell/ASR/zipformer/decode.py @@ -761,7 +761,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/aishell/ASR/zipformer/decode_bbpe.py b/egs/aishell/ASR/zipformer/decode_bbpe.py index 1ec10b059..79376c638 100755 --- a/egs/aishell/ASR/zipformer/decode_bbpe.py +++ b/egs/aishell/ASR/zipformer/decode_bbpe.py @@ -783,7 +783,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/aishell/ASR/zipformer/pretrained_bbpe.py b/egs/aishell/ASR/zipformer/pretrained_bbpe.py index 387bef98a..f2cddb9b1 100755 --- a/egs/aishell/ASR/zipformer/pretrained_bbpe.py +++ b/egs/aishell/ASR/zipformer/pretrained_bbpe.py @@ -298,7 +298,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index cd253c597..3104665b0 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -64,6 +64,7 @@ from asr_datamodule import AishellAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset import SpecAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -95,6 +96,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -239,6 +241,27 @@ def add_model_arguments(parser: argparse.ArgumentParser): chunk left-context frames will be chosen randomly from this list; else not relevant.""", ) + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + def 
get_parser(): parser = argparse.ArgumentParser( @@ -379,6 +402,27 @@ def get_parser(): with this parameter before adding to the final loss.""", ) + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", + ) + parser.add_argument( "--seed", type=int, @@ -582,8 +626,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module: encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None model = AsrModel( encoder_embed=encoder_embed, @@ -593,9 +642,27 @@ def get_model(params: AttributeDict) -> nn.Module: encoder_dim=int(max(params.encoder_dim.split(","))), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, ) return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment def load_checkpoint_if_available( params: AttributeDict, @@ -722,6 +789,7 @@ def compute_loss( graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -738,6 +806,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" @@ -757,6 +827,21 @@ def compute_loss( y = graph_compiler.texts_to_ids(texts) y = k2.RaggedTensor(y).to(device) + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): losses = model( x=feature, @@ -765,25 +850,40 @@ def compute_loss( prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, ) - simple_loss, pruned_loss = losses[:2] + if params.use_ctc: + simple_loss, pruned_loss, ctc_loss, _, cr_loss = losses[:5] + else: + simple_loss, pruned_loss = losses[:2] - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) + loss = 0.0 - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss + assert loss.requires_grad == is_training info = MetricsTracker() @@ -793,8 +893,13 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() return loss, info @@ -842,6 +947,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -868,6 +974,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. 
tb_writer: @@ -910,13 +1018,14 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, graph_compiler=graph_compiler, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1082,6 +1191,9 @@ def run(rank, world_size, args): params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + logging.info(params) logging.info("About to create model") @@ -1090,6 +1202,12 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1199,6 +1317,7 @@ def run(rank, world_size, args): optimizer=optimizer, graph_compiler=graph_compiler, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) @@ -1226,6 +1345,7 @@ def run(rank, world_size, args): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1292,6 +1412,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, graph_compiler: CharCtcTrainingGraphCompiler, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1302,13 +1423,14 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, graph_compiler=graph_compiler, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad() @@ -1343,8 +1465,7 @@ def main(): run(rank=0, world_size=1, args=args) -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py index 46a5506db..b9d7fe8ad 100755 --- a/egs/aishell/ASR/zipformer/train_bbpe.py +++ b/egs/aishell/ASR/zipformer/train_bbpe.py @@ -92,6 +92,7 @@ from icefall.utils import ( setup_logger, str2bool, tokenize_by_CJK_char, + torch_autocast, ) @@ -495,7 +496,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -895,7 +896,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, @@ -935,8 +936,7 @@ def main(): run(rank=0, world_size=1, args=args) -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + 
torch.set_num_interop_threads(1) main() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py index f9cdfb621..4c564d04e 100644 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -104,7 +104,7 @@ class AiShell2AsrDataModule: group.add_argument( "--num-buckets", type=int, - default=30, + default=15, help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) @@ -296,8 +296,7 @@ class AiShell2AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, + buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index 9e44b4e34..93f75b36f 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -728,7 +728,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py index f04632388..1002a6645 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py @@ -226,7 +226,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 8c7448d4c..84cd2ffca 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -90,7 +90,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -734,7 +740,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1062,7 +1068,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell4/ASR/local/prepare_lang.py 
b/egs/aishell4/ASR/local/prepare_lang.py index c8cf9b881..aa23c4cb3 100755 --- a/egs/aishell4/ASR/local/prepare_lang.py +++ b/egs/aishell4/ASR/local/prepare_lang.py @@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py index e8b7f71b7..f85d0552f 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py @@ -238,7 +238,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index a354f761e..ab97f8677 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -83,7 +83,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -727,7 +733,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) # print(batch["supervisions"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1034,7 +1040,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/alimeeting/ASR/local/prepare_lang.py b/egs/alimeeting/ASR/local/prepare_lang.py index c8cf9b881..aa23c4cb3 100755 --- a/egs/alimeeting/ASR/local/prepare_lang.py +++ b/egs/alimeeting/ASR/local/prepare_lang.py @@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. 
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py index a738bb3fb..7566f9a5f 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py @@ -224,7 +224,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index 30154291d..172d94862 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -79,7 +79,13 @@ from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -638,7 +644,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -912,7 +918,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. 
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py index 30879d8d2..855aeca12 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py @@ -73,7 +73,13 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -782,7 +788,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1127,7 +1133,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py index 9999894d1..712855733 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py @@ -672,7 +672,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py index d62cdadb7..8922717ef 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -71,7 +71,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -773,7 +779,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1134,7 +1140,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py index adc6a8495..d5025b477 100755 --- a/egs/ami/SURT/dprnn_zipformer/train.py +++ 
b/egs/ami/SURT/dprnn_zipformer/train.py @@ -76,7 +76,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -1067,7 +1073,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1257,7 +1263,7 @@ def run(rank, world_size, args): logging.info( f"Initializing model with checkpoint from {params.model_init_ckpt}" ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device) + init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False) model.load_state_dict(init_ckpt["model"], strict=False) if world_size > 1: diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py index ac5b0dadc..35b3ced31 100755 --- a/egs/ami/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/ami/SURT/dprnn_zipformer/train_adapt.py @@ -76,7 +76,13 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -1058,7 +1064,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1248,7 +1254,7 @@ def run(rank, world_size, args): logging.info( f"Initializing model with checkpoint from {params.model_init_ckpt}" ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device) + init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False) model.load_state_dict(init_ckpt["model"], strict=False) if world_size > 1: diff --git a/egs/audioset/AT/zipformer/pretrained.py b/egs/audioset/AT/zipformer/pretrained.py index bdbd799fa..8876b5889 100755 --- a/egs/audioset/AT/zipformer/pretrained.py +++ b/egs/audioset/AT/zipformer/pretrained.py @@ -141,7 +141,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index 67c703364..caf8accb2 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -74,6 +74,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -799,7 +800,7 @@ def train_one_epoch( num_samples += batch_size try: - 
with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1148,7 +1149,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/baker_zh/TTS/.gitignore b/egs/baker_zh/TTS/.gitignore new file mode 100644 index 000000000..6441cd500 --- /dev/null +++ b/egs/baker_zh/TTS/.gitignore @@ -0,0 +1,6 @@ +path.sh +*.onnx +*.wav +generator_v1 +generator_v2 +generator_v3 diff --git a/egs/baker_zh/TTS/README.md b/egs/baker_zh/TTS/README.md new file mode 100644 index 000000000..7120c6f79 --- /dev/null +++ b/egs/baker_zh/TTS/README.md @@ -0,0 +1,146 @@ +# Introduction + +It is for the dataset from +https://en.data-baker.com/datasets/freeDatasets/ + +The dataset contains 10000 Chinese sentences of a native Chinese female speaker, +which is about 12 hours. + + +**Note**: The dataset is for non-commercial use only. + + +# matcha + +[./matcha](./matcha) contains the code for training [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS) + +Checkpoints and training logs can be found [here](https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27). +The pull-request for this recipe can be found at + +The training command is given below: +```bash +python3 ./matcha/train.py \ + --exp-dir ./matcha/exp-1/ \ + --num-workers 4 \ + --world-size 1 \ + --num-epochs 2000 \ + --max-duration 1200 \ + --bucketing-sampler 1 \ + --start-epoch 1 +``` + +To inference, use: + +```bash +# Download Hifigan vocoder. We use Hifigan v2 below. You can select from v1, v2, or v3 + +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 + +python3 ./matcha/infer.py \ + --epoch 2000 \ + --exp-dir ./matcha/exp-1 \ + --vocoder ./generator_v2 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json \ + --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ + --output-wav ./generated.wav +``` + +```bash +soxi ./generated.wav +``` + +prints: +``` +Input File : './generated.wav' +Channels : 1 +Sample Rate : 22050 +Precision : 16-bit +Duration : 00:00:17.31 = 381696 samples ~ 1298.29 CDDA sectors +File Size : 763k +Bit Rate : 353k +Sample Encoding: 16-bit Signed Integer PCM +``` + +https://github.com/user-attachments/assets/88d4e88f-ebc4-4f32-b216-16d46b966024 + + +To export the checkpoint to onnx: +```bash +python3 ./matcha/export_onnx.py \ + --exp-dir ./matcha/exp-1 \ + --epoch 2000 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json +``` + +The above command generates the following files: +``` +-rw-r--r-- 1 kuangfangjun root 72M Dec 27 18:53 model-steps-2.onnx +-rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-3.onnx +-rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-4.onnx +-rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:55 model-steps-5.onnx +-rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:57 model-steps-6.onnx +``` + +where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver. + +**HINT**: If you get the following error while running `export_onnx.py`: + +``` +torch.onnx.errors.UnsupportedOperatorError: Exporting the operator +'aten::scaled_dot_product_attention' to ONNX opset version 14 is not supported. +``` + +please use `torch>=2.2.0`. 
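+
+**HINT**: `./matcha/export_onnx.py` attaches meta data such as `model_type`,
+`sample_rate` and `num_ode_steps` to each exported model (see `add_meta_data()`
+in that script). As a quick sanity check after export, you can dump the meta
+data with a short sketch like the one below (it only assumes the `onnx` Python
+package is installed):
+
+```python
+import onnx
+
+# Load one of the exported acoustic models and print its meta data.
+model = onnx.load("model-steps-2.onnx")
+for prop in model.metadata_props:
+    print(f"{prop.key}={prop.value}")
+```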
+ +To export the Hifigan vocoder to onnx, please use: + +```bash +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 + +python3 ./matcha/export_onnx_hifigan.py +``` + +The above command generates 3 files: + + - hifigan_v1.onnx + - hifigan_v2.onnx + - hifigan_v3.onnx + +**HINT**: You can download pre-exported hifigan ONNX models from + + +To use the generated onnx files to generate speech from text, please run: + +```bash + +# First, generate ./lexicon.txt +python3 ./matcha/generate_lexicon.py + +python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-4.onnx \ + --vocoder ./hifigan_v2.onnx \ + --tokens ./data/tokens.txt \ + --lexicon ./lexicon.txt \ + --input-text "在一个阳光明媚的夏天,小马、小羊和小狗它们一块儿在广阔的草地上,嬉戏玩耍,这时小猴来了,还带着它心爱的足球活蹦乱跳地跑前、跑后教小马、小羊、小狗踢足球。" \ + --output-wav ./1.wav +``` + +```bash +soxi ./1.wav + +Input File : './1.wav' +Channels : 1 +Sample Rate : 22050 +Precision : 16-bit +Duration : 00:00:16.37 = 360960 samples ~ 1227.76 CDDA sectors +File Size : 722k +Bit Rate : 353k +Sample Encoding: 16-bit Signed Integer PCM +``` + +https://github.com/user-attachments/assets/578d04bb-fee8-47e5-9984-a868dcce610e + diff --git a/egs/baker_zh/TTS/local/audio.py b/egs/baker_zh/TTS/local/audio.py new file mode 120000 index 000000000..b70d91c92 --- /dev/null +++ b/egs/baker_zh/TTS/local/audio.py @@ -0,0 +1 @@ +../matcha/audio.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py new file mode 100755 index 000000000..deb344d14 --- /dev/null +++ b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the baker-zh dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from fbank import MatchaFbank, MatchaFbankConfig +from lhotse import CutSet, LilcomChunkyWriter, load_manifest +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--num-jobs", + type=int, + default=4, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ """, + ) + return parser + + +def compute_fbank_baker_zh(num_jobs: int): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + if num_jobs < 1: + num_jobs = os.cpu_count() + + logging.info(f"num_jobs: {num_jobs}") + logging.info(f"src_dir: {src_dir}") + logging.info(f"output_dir: {output_dir}") + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=22050, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + if not torch.cuda.is_available(): + config.device = "cpu" + + prefix = "baker_zh" + suffix = "jsonl.gz" + + extractor = MatchaFbank(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts.{suffix}" + logging.info(f"Processing {cuts_filename}") + cut_set = load_manifest(src_dir / cuts_filename).resample(22050) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats", + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + # Torch's multithreaded behavior needs to be disabled or + # it wastes a lot of CPU and slow things down. + # Do this outside of main() in case it needs to take effect + # even when we are not invoking the main (e.g. when spawning subprocesses). + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_parser().parse_args() + compute_fbank_baker_zh(args.num_jobs) diff --git a/egs/baker_zh/TTS/local/compute_fbank_statistics.py b/egs/baker_zh/TTS/local/compute_fbank_statistics.py new file mode 120000 index 000000000..fd1d8b52e --- /dev/null +++ b/egs/baker_zh/TTS/local/compute_fbank_statistics.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/compute_fbank_statistics.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/convert_text_to_tokens.py b/egs/baker_zh/TTS/local/convert_text_to_tokens.py new file mode 100755 index 000000000..bf59cb466 --- /dev/null +++ b/egs/baker_zh/TTS/local/convert_text_to_tokens.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 + +import argparse +import re +from typing import List + +import jieba +from lhotse import load_manifest +from pypinyin import Style, lazy_pinyin, load_phrases_dict + +load_phrases_dict( + { + "行长": [["hang2"], ["zhang3"]], + "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]], + } +) + +whiter_space_re = re.compile(r"\s+") + +punctuations_re = [ + (re.compile(x[0], re.IGNORECASE), x[1]) + for x in [ + (",", ","), + ("。", "."), + ("!", "!"), + ("?", "?"), + ("“", '"'), + ("”", '"'), + ("‘", "'"), + ("’", "'"), + (":", ":"), + ("、", ","), + ("B", "逼"), + ("P", "批"), + ] +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--in-file", + type=str, + required=True, + help="Input cutset.", + ) + + parser.add_argument( + "--out-file", + type=str, + required=True, + help="Output cutset.", + ) + + return parser + + +def normalize_white_spaces(text): + return whiter_space_re.sub(" ", text) + + +def normalize_punctuations(text): + for regex, replacement in punctuations_re: + text = re.sub(regex, replacement, text) + return text + + +def split_text(text: str) -> List[str]: + """ + Example input: '你好呀,You are 一个好人。 去银行存钱?How about you?' 
+ Example output: ['你好', '呀', ',', 'you are', '一个', '好人', '.', '去', '银行', '存钱', '?', 'how about you', '?'] + """ + text = text.lower() + text = normalize_white_spaces(text) + text = normalize_punctuations(text) + ans = [] + + for seg in jieba.cut(text): + if seg in ",.!?:\"'": + ans.append(seg) + elif seg == " " and len(ans) > 0: + if ord("a") <= ord(ans[-1][-1]) <= ord("z"): + ans[-1] += seg + elif ord("a") <= ord(seg[0]) <= ord("z"): + if len(ans) == 0: + ans.append(seg) + continue + + if ans[-1][-1] == " ": + ans[-1] += seg + continue + + ans.append(seg) + else: + ans.append(seg) + + ans = [s.strip() for s in ans] + return ans + + +def main(): + args = get_parser().parse_args() + cuts = load_manifest(args.in_file) + for c in cuts: + assert len(c.supervisions) == 1, (len(c.supervisions), c.supervisions) + text = c.supervisions[0].normalized_text + + text_list = split_text(text) + tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True) + + c.tokens = tokens + + cuts.to_file(args.out_file) + + print(f"saved to {args.out_file}") + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/local/fbank.py b/egs/baker_zh/TTS/local/fbank.py new file mode 120000 index 000000000..5bcf1fde5 --- /dev/null +++ b/egs/baker_zh/TTS/local/fbank.py @@ -0,0 +1 @@ +../matcha/fbank.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/generate_tokens.py b/egs/baker_zh/TTS/local/generate_tokens.py new file mode 100755 index 000000000..b2abe1a71 --- /dev/null +++ b/egs/baker_zh/TTS/local/generate_tokens.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 + +""" +This file generates the file tokens.txt. + +Usage: + +python3 ./local/generate_tokens.py > data/tokens.txt +""" + + +import argparse +from typing import List + +import jieba +from pypinyin import Style, lazy_pinyin, pinyin_dict + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to to save tokens.txt.", + ) + + return parser + + +def generate_token_list() -> List[str]: + token_set = set() + + word_dict = pinyin_dict.pinyin_dict + i = 0 + for key in word_dict: + if not (0x4E00 <= key <= 0x9FFF): + continue + + w = chr(key) + t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0] + token_set.add(t) + + no_digit = set() + for t in token_set: + if t[-1] not in "1234": + no_digit.add(t) + else: + no_digit.add(t[:-1]) + + no_digit.add("dei") + no_digit.add("tou") + no_digit.add("dia") + + for t in no_digit: + token_set.add(t) + for i in range(1, 5): + token_set.add(f"{t}{i}") + + ans = list(token_set) + ans.sort() + + punctuations = list(",.!?:\"'") + ans = punctuations + ans + + # use ID 0 for blank + # Use ID 1 of _ for padding + ans.insert(0, " ") + ans.insert(1, "_") # + + return ans + + +def main(): + args = get_parser().parse_args() + token_list = generate_token_list() + with open(args.tokens, "w", encoding="utf-8") as f: + for indx, token in enumerate(token_list): + f.write(f"{token} {indx}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/local/validate_manifest.py b/egs/baker_zh/TTS/local/validate_manifest.py new file mode 100755 index 000000000..4e31028f7 --- /dev/null +++ b/egs/baker_zh/TTS/local/validate_manifest.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/baker_zh_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet), type(cut_set) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/baker_zh/TTS/matcha/__init__.py b/egs/baker_zh/TTS/matcha/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/baker_zh/TTS/matcha/audio.py b/egs/baker_zh/TTS/matcha/audio.py new file mode 120000 index 000000000..62d3959d6 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/audio.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/audio.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/export_onnx.py b/egs/baker_zh/TTS/matcha/export_onnx.py new file mode 100755 index 000000000..28efbfe61 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/export_onnx.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script exports a Matcha-TTS model to ONNX. +Note that the model outputs fbank. You need to use a vocoder to convert +it to audio. See also ./export_onnx_hifigan.py + +python3 ./matcha/export_onnx.py \ + --exp-dir ./matcha/exp-1 \ + --epoch 2000 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json + +""" + +import argparse +import json +import logging +from pathlib import Path +from typing import Any, Dict + +import onnx +import torch +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=2000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp-new-3", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, Any]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class ModelWrapper(torch.nn.Module): + def __init__(self, model, num_steps: int = 5): + super().__init__() + self.model = model + self.num_steps = num_steps + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + noise_scale: torch.Tensor, + length_scale: torch.Tensor, + ) -> torch.Tensor: + """ + Args: : + x: (batch_size, num_tokens), torch.int64 + x_lengths: (batch_size,), torch.int64 + noise_scale: (1,), torch.float32 + length_scale (1,), torch.float32 + Returns: + audio: (batch_size, num_samples) + + """ + mel = self.model.synthesise( + x=x, + x_lengths=x_lengths, + n_timesteps=self.num_steps, + temperature=noise_scale, + length_scale=length_scale, + )["mel"] + # mel: (batch_size, feat_dim, num_frames) + + return mel + + +@torch.inference_mode() +def main(): + parser = get_parser() + args = parser.parse_args() + params = get_params() + + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.pad_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + for num_steps in [2, 3, 4, 5, 6]: + logging.info(f"num_steps: {num_steps}") + wrapper = ModelWrapper(model, num_steps=num_steps) + wrapper.eval() + + # Use a large value so the rotary position embedding in the text + # encoder has a large initial length + x = torch.ones(1, 1000, dtype=torch.int64) + x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1.0]) + length_scale = torch.tensor([1.0]) + + opset_version = 14 + filename = f"model-steps-{num_steps}.onnx" + torch.onnx.export( + wrapper, + (x, x_lengths, noise_scale, length_scale), + filename, + opset_version=opset_version, + input_names=["x", "x_length", "noise_scale", "length_scale"], + output_names=["mel"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "x_length": {0: "N"}, + "mel": {0: "N", 2: "L"}, + }, + ) + + meta_data = { + "model_type": "matcha-tts", + "language": "Chinese", + "has_espeak": 0, + "n_speakers": 1, + "jieba": 1, + "sample_rate": 22050, + "version": 1, + "pad_id": params.pad_id, + "model_author": "icefall", + "maintainer": "k2-fsa", + "dataset": "baker-zh", + "use_eos_bos": 0, + "dataset_url": "https://www.data-baker.com/open_source.html", + 
"dataset_comment": "The dataset is for non-commercial use only.", + "num_ode_steps": num_steps, + } + add_meta_data(filename=filename, meta_data=meta_data) + print(meta_data) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py b/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py new file mode 120000 index 000000000..d0b8af15b --- /dev/null +++ b/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/export_onnx_hifigan.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/fbank.py b/egs/baker_zh/TTS/matcha/fbank.py new file mode 120000 index 000000000..3cfb7fe3f --- /dev/null +++ b/egs/baker_zh/TTS/matcha/fbank.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/fbank.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/generate_lexicon.py b/egs/baker_zh/TTS/matcha/generate_lexicon.py new file mode 100755 index 000000000..f26f28e91 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/generate_lexicon.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +import jieba +from pypinyin import Style, lazy_pinyin, load_phrases_dict, phrases_dict, pinyin_dict +from tokenizer import Tokenizer + +load_phrases_dict( + { + "行长": [["hang2"], ["zhang3"]], + "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]], + } +) + + +def main(): + filename = "lexicon.txt" + tokens = "./data/tokens.txt" + tokenizer = Tokenizer(tokens) + + word_dict = pinyin_dict.pinyin_dict + phrases = phrases_dict.phrases_dict + + i = 0 + with open(filename, "w", encoding="utf-8") as f: + for key in word_dict: + if not (0x4E00 <= key <= 0x9FFF): + continue + + w = chr(key) + tokens = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0] + + f.write(f"{w} {tokens}\n") + + for key in phrases: + tokens = lazy_pinyin(key, style=Style.TONE3, tone_sandhi=True) + tokens = " ".join(tokens) + + f.write(f"{key} {tokens}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/matcha/hifigan b/egs/baker_zh/TTS/matcha/hifigan new file mode 120000 index 000000000..c0a91072c --- /dev/null +++ b/egs/baker_zh/TTS/matcha/hifigan @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/hifigan \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/infer.py b/egs/baker_zh/TTS/matcha/infer.py new file mode 100755 index 000000000..142d9fdfe --- /dev/null +++ b/egs/baker_zh/TTS/matcha/infer.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) +""" +python3 ./matcha/infer.py \ + --epoch 2000 \ + --exp-dir ./matcha/exp-1 \ + --vocoder ./generator_v2 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json \ + --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ + --output-wav ./generated.wav +""" + +import argparse +import datetime as dt +import json +import logging +from pathlib import Path + +import soundfile as sf +import torch +import torch.nn as nn +from hifigan.config import v1, v2, v3 +from hifigan.denoiser import Denoiser +from hifigan.models import Generator as HiFiGAN +from local.convert_text_to_tokens import split_text +from pypinyin import Style, lazy_pinyin +from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import BakerZhTtsDataModule + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=4000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--vocoder", + type=Path, + default="./generator_v1", + help="Path to the vocoder", + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + # The following arguments are used for inference on single text + parser.add_argument( + "--input-text", + type=str, + required=False, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=False, + help="The filename of the wave to save the generated speech", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=22050, + help="The sampling rate of the generated speech (default: 22050 for baker_zh)", + ) + + return parser + + +def load_vocoder(checkpoint_path: Path) -> nn.Module: + checkpoint_path = str(checkpoint_path) + if checkpoint_path.endswith("v1"): + h = AttributeDict(v1) + elif checkpoint_path.endswith("v2"): + h = AttributeDict(v2) + elif checkpoint_path.endswith("v3"): + h = AttributeDict(v3) + else: + raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") + + hifigan = HiFiGAN(h).to("cpu") + hifigan.load_state_dict( + torch.load(checkpoint_path, map_location="cpu", weights_only=False)["generator"] + ) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + + +def to_waveform( + mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module +) -> torch.Tensor: + audio = vocoder(mel).clamp(-1, 1) + audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() + return audio.squeeze() + + +def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict: + text = split_text(text) + tokens = lazy_pinyin(text, style=Style.TONE3, tone_sandhi=True) + + x = tokenizer.texts_to_token_ids([tokens]) + x = torch.tensor(x, dtype=torch.long, device=device) + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) + return {"x_orig": text, "x": x, "x_lengths": x_lengths} + + +def synthesize( + model: 
nn.Module, + tokenizer: Tokenizer, + n_timesteps: int, + text: str, + length_scale: float, + temperature: float, + device: str = "cpu", + spks=None, +) -> dict: + text_processed = process_text(text=text, tokenizer=tokenizer, device=device) + start_t = dt.datetime.now() + output = model.synthesise( + text_processed["x"], + text_processed["x_lengths"], + n_timesteps=n_timesteps, + temperature=temperature, + spks=spks, + length_scale=length_scale, + ) + # merge everything to one dict + output.update({"start_t": start_t, **text_processed}) + return output + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + vocoder: nn.Module, + denoiser: nn.Module, + tokenizer: Tokenizer, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + texts = [c.supervisions[0].normalized_text for c in batch["cut"]] + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + for i in range(batch_size): + output = synthesize( + model=model, + tokenizer=tokenizer, + n_timesteps=params.n_timesteps, + text=texts[i], + length_scale=params.length_scale, + temperature=params.temperature, + device=device, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write( + file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav", + data=output["waveform"], + samplerate=params.data_args.sampling_rate, + subtype="PCM_16", + ) + sf.write( + file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav", + data=audio[i].numpy(), + samplerate=params.data_args.sampling_rate, + subtype="PCM_16", + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + +@torch.inference_mode() +def main(): + parser = get_parser() + BakerZhTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + + # Number of ODE Solver steps + params.n_timesteps = 2 + + # Changes to the speaking rate + params.length_scale = 1.0 + + # Sampling 
temperature + params.temperature = 0.667 + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + model.eval() + + # we need cut ids to organize tts results. + args.return_cuts = True + baker_zh = BakerZhTtsDataModule(args) + + test_cuts = baker_zh.test_cuts() + test_dl = baker_zh.test_dataloaders(test_cuts) + + if not Path(params.vocoder).is_file(): + raise ValueError(f"{params.vocoder} does not exist") + + vocoder = load_vocoder(params.vocoder) + vocoder.to(device) + + denoiser = Denoiser(vocoder, mode="zeros") + denoiser.to(device) + + if params.input_text is not None and params.output_wav is not None: + logging.info("Synthesizing a single text") + output = synthesize( + model=model, + tokenizer=tokenizer, + n_timesteps=params.n_timesteps, + text=params.input_text, + length_scale=params.length_scale, + temperature=params.temperature, + device=device, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write( + file=params.output_wav, + data=output["waveform"], + samplerate=params.sampling_rate, + subtype="PCM_16", + ) + else: + logging.info("Decoding the test set") + infer_dataset( + dl=test_dl, + params=params, + model=model, + vocoder=vocoder, + denoiser=denoiser, + tokenizer=tokenizer, + ) + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/matcha/model.py b/egs/baker_zh/TTS/matcha/model.py new file mode 120000 index 000000000..8a1b812a9 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/model.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/model.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/models b/egs/baker_zh/TTS/matcha/models new file mode 120000 index 000000000..09a862665 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/models @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/models \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/monotonic_align b/egs/baker_zh/TTS/matcha/monotonic_align new file mode 120000 index 000000000..d0a0dd6b5 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/monotonic_align @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/monotonic_align \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/onnx_pretrained.py b/egs/baker_zh/TTS/matcha/onnx_pretrained.py new file mode 100755 index 000000000..f6b7f7cae --- /dev/null +++ b/egs/baker_zh/TTS/matcha/onnx_pretrained.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) + +""" +python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-4.onnx \ + --vocoder ./hifigan_v2.onnx \ + --tokens ./data/tokens.txt \ + --lexicon ./lexicon.txt \ + --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ + --output-wav ./b.wav +""" + +import argparse +import datetime as dt +import logging +import re +from typing import Dict, List + +import jieba +import onnxruntime as ort +import soundfile as sf +import torch +from infer import load_vocoder +from utils import intersperse + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--acoustic-model", + type=str, + required=True, + help="Path to the acoustic model", + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--lexicon", + type=str, + required=True, + help="Path to the lexicon.txt", + ) + + parser.add_argument( + "--vocoder", + type=str, + required=True, + help="Path to the vocoder", + ) + + parser.add_argument( + "--input-text", + type=str, + required=True, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=True, + help="The filename of the wave to save the generated speech", + ) + + return parser + + +class OnnxHifiGANModel: + def __init__( + self, + filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + for i in self.model.get_inputs(): + print(i) + + print("-----") + + for i in self.model.get_outputs(): + print(i) + + def __call__(self, x: torch.tensor): + assert x.ndim == 3, x.shape + assert x.shape[0] == 1, x.shape + + audio = self.model.run( + [self.model.get_outputs()[0].name], + { + self.model.get_inputs()[0].name: x.numpy(), + }, + )[0] + # audio: (batch_size, num_samples) + + return torch.from_numpy(audio) + + +class OnnxModel: + def __init__( + self, + filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 2 + + self.session_opts = session_opts + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + metadata = self.model.get_modelmeta().custom_metadata_map + self.sample_rate = int(metadata["sample_rate"]) + + for i in self.model.get_inputs(): + print(i) + + print("-----") + + for i in self.model.get_outputs(): + print(i) + + def __call__(self, x: torch.tensor): + assert x.ndim == 2, x.shape + assert x.shape[0] == 1, x.shape + + x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) + print("x_lengths", x_lengths) + print("x", x.shape) + + noise_scale = torch.tensor([1.0], dtype=torch.float32) + length_scale = torch.tensor([1.0], dtype=torch.float32) + + mel = self.model.run( + [self.model.get_outputs()[0].name], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lengths.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: length_scale.numpy(), + }, + )[0] + # mel: (batch_size, feat_dim, num_frames) + + return torch.from_numpy(mel) + + +def 
read_tokens(filename: str) -> Dict[str, int]: + token2id = dict() + with open(filename, encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + if len(info) == 1: + # case of space + token = " " + idx = int(info[0]) + else: + token, idx = info[0], int(info[1]) + assert token not in token2id, token + token2id[token] = idx + return token2id + + +def read_lexicon(filename: str) -> Dict[str, List[str]]: + word2token = dict() + with open(filename, encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + w = info[0] + tokens = info[1:] + word2token[w] = tokens + return word2token + + +def convert_word_to_tokens(word2tokens: Dict[str, List[str]], word: str) -> List[str]: + if word in word2tokens: + return word2tokens[word] + + if len(word) == 1: + return [] + + ans = [] + for w in word: + t = convert_word_to_tokens(word2tokens, w) + ans.extend(t) + return ans + + +def normalize_text(text): + whiter_space_re = re.compile(r"\s+") + + punctuations_re = [ + (re.compile(x[0], re.IGNORECASE), x[1]) + for x in [ + (",", ","), + ("。", "."), + ("!", "!"), + ("?", "?"), + ("“", '"'), + ("”", '"'), + ("‘", "'"), + ("’", "'"), + (":", ":"), + ("、", ","), + ] + ] + + for regex, replacement in punctuations_re: + text = re.sub(regex, replacement, text) + return text + + +@torch.no_grad() +def main(): + params = get_parser().parse_args() + logging.info(vars(params)) + token2id = read_tokens(params.tokens) + word2tokens = read_lexicon(params.lexicon) + + text = normalize_text(params.input_text) + seg = jieba.cut(text) + tokens = [] + for s in seg: + if s in token2id: + tokens.append(s) + continue + + t = convert_word_to_tokens(word2tokens, s) + if t: + tokens.extend(t) + + model = OnnxModel(params.acoustic_model) + vocoder = OnnxHifiGANModel(params.vocoder) + + x = [] + for t in tokens: + if t in token2id: + x.append(token2id[t]) + + x = intersperse(x, item=token2id["_"]) + + x = torch.tensor(x, dtype=torch.int64).unsqueeze(0) + + start_t = dt.datetime.now() + mel = model(x) + end_t = dt.datetime.now() + + start_t2 = dt.datetime.now() + audio = vocoder(mel) + end_t2 = dt.datetime.now() + + print("audio", audio.shape) # (1, 1, num_samples) + audio = audio.squeeze() + + sample_rate = model.sample_rate + + t = (end_t - start_t).total_seconds() + t2 = (end_t2 - start_t2).total_seconds() + rtf_am = t * sample_rate / audio.shape[-1] + rtf_vocoder = t2 * sample_rate / audio.shape[-1] + print("RTF for acoustic model ", rtf_am) + print("RTF for vocoder", rtf_vocoder) + + # skip denoiser + sf.write(params.output_wav, audio, sample_rate, "PCM_16") + logging.info(f"Saved to {params.output_wav}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() + +""" + +|HifiGAN |RTF |#Parameters (M)| +|----------|-----|---------------| +|v1 |0.818| 13.926 | +|v2 |0.101| 0.925 | +|v3 |0.118| 1.462 | + +|Num steps|Acoustic Model RTF| +|---------|------------------| +| 2 | 0.039 | +| 3 | 0.047 | +| 4 | 0.071 | +| 5 | 0.076 | +| 6 | 0.103 | + +""" diff --git a/egs/baker_zh/TTS/matcha/tokenizer.py b/egs/baker_zh/TTS/matcha/tokenizer.py new file mode 100644 index 000000000..dda82c29d --- /dev/null +++ b/egs/baker_zh/TTS/matcha/tokenizer.py @@ -0,0 +1,119 @@ +# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) + +import logging +from typing import Dict, List + +import tacotron_cleaner.cleaners + +try: + from piper_phonemize import phonemize_espeak +except Exception as ex: + raise RuntimeError( + f"{ex}\nPlease run\n" + "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html" + ) + +from utils import intersperse + + +# This tokenizer supports both English and Chinese. +# We assume you have used +# ../local/convert_text_to_tokens.py +# to process your text +class Tokenizer(object): + def __init__(self, tokens: str): + """ + Args: + tokens: the file that maps tokens to ids + """ + # Parse token file + self.token2id: Dict[str, int] = {} + with open(tokens, "r", encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + if len(info) == 1: + # case of space + token = " " + id = int(info[0]) + else: + token, id = info[0], int(info[1]) + assert token not in self.token2id, token + self.token2id[token] = id + + # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md + self.pad_id = self.token2id["_"] # padding + self.space_id = self.token2id[" "] # word separator (whitespace) + + self.vocab_size = len(self.token2id) + + def texts_to_token_ids( + self, + sentence_list: List[List[str]], + intersperse_blank: bool = True, + lang: str = "en-us", + ) -> List[List[int]]: + """ + Args: + sentence_list: + A list of sentences. + intersperse_blank: + Whether to intersperse blanks in the token sequence. + lang: + Language argument passed to phonemize_espeak(). + + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for sentence in sentence_list: + tokens_list = [] + for word in sentence: + if word in self.token2id: + tokens_list.append(word) + continue + + tmp_tokens_list = phonemize_espeak(word, lang) + for t in tmp_tokens_list: + tokens_list.extend(t) + + token_ids = [] + for t in tokens_list: + if t not in self.token2id: + logging.warning(f"Skip OOV {t} {sentence}") + continue + + if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id: + continue + + token_ids.append(self.token2id[t]) + + if intersperse_blank: + token_ids = intersperse(token_ids, self.pad_id) + + token_ids_list.append(token_ids) + + return token_ids_list + + +def test_tokenizer(): + import jieba + from pypinyin import Style, lazy_pinyin + + tokenizer = Tokenizer("data/tokens.txt") + text1 = "今天is Monday, tomorrow is 星期二" + text2 = "你好吗? 我很好, how about you?" + + text1 = list(jieba.cut(text1)) + text2 = list(jieba.cut(text2)) + tokens1 = lazy_pinyin(text1, style=Style.TONE3, tone_sandhi=True) + tokens2 = lazy_pinyin(text2, style=Style.TONE3, tone_sandhi=True) + print(tokens1) + print(tokens2) + + ids = tokenizer.texts_to_token_ids([tokens1, tokens2]) + print(ids) + + +if __name__ == "__main__": + test_tokenizer() diff --git a/egs/baker_zh/TTS/matcha/train.py b/egs/baker_zh/TTS/matcha/train.py new file mode 100755 index 000000000..ed2ba49b9 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/train.py @@ -0,0 +1,717 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) + + +import argparse +import json +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Union + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.utils import fix_random_seed +from model import fix_len_compatibility +from models.matcha_tts import MatchaTTS +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import BakerZhTtsDataModule +from utils import MetricsTracker + +from icefall.checkpoint import load_checkpoint, save_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12335, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. 
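+        For example, with --save-every-n 100 a checkpoint is written only at
+        epochs 100, 200, ..., and at the final epoch.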
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_data_statistics(): + return AttributeDict( + { + "mel_mean": 0, + "mel_std": 1, + } + ) + + +def _get_data_params() -> AttributeDict: + params = AttributeDict( + { + "name": "baker-zh", + "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", + "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", + # "batch_size": 64, + # "num_workers": 1, + # "pin_memory": False, + "cleaners": ["english_cleaners2"], + "add_blank": True, + "n_spks": 1, + "n_fft": 1024, + "n_feats": 80, + "sampling_rate": 22050, + "hop_length": 256, + "win_length": 1024, + "f_min": 0, + "f_max": 8000, + "seed": 1234, + "load_durations": False, + "data_statistics": get_data_statistics(), + } + ) + return params + + +def _get_model_params() -> AttributeDict: + n_feats = 80 + filter_channels_dp = 256 + encoder_params_p_dropout = 0.1 + params = AttributeDict( + { + "n_spks": 1, # for baker-zh. + "spk_emb_dim": 64, + "n_feats": n_feats, + "out_size": None, # or use 172 + "prior_loss": True, + "use_precomputed_durations": False, + "data_statistics": get_data_statistics(), + "encoder": AttributeDict( + { + "encoder_type": "RoPE Encoder", # not used + "encoder_params": AttributeDict( + { + "n_feats": n_feats, + "n_channels": 192, + "filter_channels": 768, + "filter_channels_dp": filter_channels_dp, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + "spk_emb_dim": 64, + "n_spks": 1, + "prenet": True, + } + ), + "duration_predictor_params": AttributeDict( + { + "filter_channels_dp": filter_channels_dp, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + } + ), + } + ), + "decoder": AttributeDict( + { + "channels": [256, 256], + "dropout": 0.05, + "attention_head_dim": 64, + "n_blocks": 1, + "num_mid_blocks": 2, + "num_heads": 2, + "act_fn": "snakebeta", + } + ), + "cfm": AttributeDict( + { + "name": "CFM", + "solver": "euler", + "sigma_min": 1e-4, + } + ), + "optimizer": AttributeDict( + { + "lr": 1e-4, + "weight_decay": 0.0, + } + ), + } + ) + + return params + + +def get_params(): + params = AttributeDict( + { + "model_args": _get_model_params(), + "data_args": _get_data_params(), + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 10, + "valid_interval": 1500, + "env_info": get_env_info(), + } + ) + return params + + +def get_model(params): + m = MatchaTTS(**params.model_args) + return m + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
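+    # A minimal resume sketch (hypothetical paths): if matcha/exp/epoch-10.pt
+    # exists, then running
+    #   ./matcha/train.py --exp-dir matcha/exp --start-epoch 11
+    # loads matcha/exp/epoch-10.pt here and restores the best-loss and batch
+    # counters listed below before training continues.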
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params): + """Parse batch data""" + mel_mean = params.data_args.data_statistics.mel_mean + mel_std_inv = 1 / params.data_args.data_statistics.mel_std + for i in range(batch["features"].shape[0]): + n = batch["features_lens"][i] + batch["features"][i : i + 1, :n, :] = ( + batch["features"][i : i + 1, :n, :] - mel_mean + ) * mel_std_inv + batch["features"][i : i + 1, n:, :] = 0 + + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + + tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + max_feature_length = fix_len_compatibility(features.shape[1]) + if max_feature_length > features.shape[1]: + pad = max_feature_length - features.shape[1] + features = torch.nn.functional.pad(features, (0, 0, 0, pad)) + + # features_lens[features_lens.argmax()] += pad + + return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long() + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device, params) + + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) + + batch_size = len(batch["tokens"]) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + s = 0 + + for key, value in losses.items(): + v = value.detach().item() + loss_info[key] = v * batch_size + s += v * batch_size + + loss_info["tot_loss"] = s + + # summary stats + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["tot_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer: Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. 
+ + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer=optimizer, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + # audio: (N, T), float32 + # features: (N, T, C), float32 + # audio_lens, (N,), int32 + # features_lens, (N,), int32 + # tokens: List[List[str]], len(tokens) == N + + batch_size = len(batch["tokens"]) + + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device, params) + try: + with autocast(enabled=params.use_fp16): + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) + + loss = sum(losses.values()) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + s = 0 + + for key, value in losses.items(): + v = value.detach().item() + loss_info[key] = v * batch_size + s += v * batch_size + + loss_info["tot_loss"] = s + + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. + # The _growth_interval of the grad scaler is configurable, + # but we can't configure it to have different + # behavior depending on the current grad scale. 
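+            # Concretely, the code below doubles the scale while it is under 8.0
+            # (or under 32.0, once every 400 batches), saves a bad-model checkpoint
+            # (once) and warns when it drops below 0.01, and raises if it falls
+            # under 1e-5, since a collapsing scale usually means the loss keeps
+            # producing inf/nan gradients.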
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, " + f"batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + rank=rank, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + "Maximum memory allocated so far is " + f"{torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["tot_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.pad_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + + logging.info(params) + print(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of parameters: {num_param}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = 
torch.optim.Adam(model.parameters(), **params.model_args.optimizer) + + logging.info("About to create datamodule") + + baker_zh = BakerZhTtsDataModule(args) + + train_cuts = baker_zh.train_cuts() + train_dl = baker_zh.train_dataloaders(train_cuts) + + valid_cuts = baker_zh.valid_cuts() + valid_dl = baker_zh.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + fix_random_seed(params.seed + epoch - 1) + if "sampler" in train_dl: + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer=optimizer, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + BakerZhTtsDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/baker_zh/TTS/matcha/tts_datamodule.py b/egs/baker_zh/TTS/matcha/tts_datamodule.py new file mode 100644 index 000000000..d2bdfb96c --- /dev/null +++ b/egs/baker_zh/TTS/matcha/tts_datamodule.py @@ -0,0 +1,340 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
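+
+# This file builds the train/valid/test DataLoaders for the baker_zh TTS recipe
+# from the Lhotse cut manifests in data/fbank (created by ./prepare.sh); see
+# BakerZhTtsDataModule.add_arguments() for the dataloader-related command-line
+# options it registers.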
+ + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from fbank import MatchaFbank, MatchaFbankConfig +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class BakerZhTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. 
Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
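+        # Each worker then re-seeds itself with seed + worker_id via _SeedWorkers,
+        # so per-worker randomness differs between workers while remaining
+        # reproducible for a fixed base seed.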
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + pin_memory=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=True, + pin_memory=True, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz" + ) diff --git a/egs/baker_zh/TTS/matcha/utils.py b/egs/baker_zh/TTS/matcha/utils.py new file mode 120000 index 000000000..ceaaea196 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/utils.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/prepare.sh b/egs/baker_zh/TTS/prepare.sh new file mode 100755 index 000000000..e15e3d850 --- /dev/null +++ b/egs/baker_zh/TTS/prepare.sh @@ -0,0 +1,151 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set 
-eou pipefail + +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download +mkdir -p $dl_dir + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: build monotonic_align lib (used by ./matcha)" + for recipe in matcha; do + if [ ! -d $recipe/monotonic_align/build ]; then + cd $recipe/monotonic_align + python3 setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib for $recipe already built" + fi + done +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # The directory $dl_dir/BANSYP contains the following 3 directories + + # ls -lh $dl_dir/BZNSYP/ + # total 0 + # drwxr-xr-x 10002 kuangfangjun root 0 Jan 4 2019 PhoneLabeling + # drwxr-xr-x 3 kuangfangjun root 0 Jan 31 2019 ProsodyLabeling + # drwxr-xr-x 10003 kuangfangjun root 0 Aug 26 17:45 Wave + + # If you have trouble accessing huggingface.co, please use + # + # cd $dl_dir + # wget https://huggingface.co/openspeech/BZNSYP/resolve/main/BZNSYP.tar.bz2 + # tar xf BZNSYP.tar.bz2 + # cd .. + + # If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink + # + # ln -sfv /path/to/BZNSYP $dl_dir/BZNSYP + # + if [ ! -d $dl_dir/BZNSYP/Wave ]; then + lhotse download baker-zh $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare baker-zh manifest" + # We assume that you have downloaded the baker corpus + # to $dl_dir/BZNSYP + mkdir -p data/manifests + if [ ! -e data/manifests/.baker-zh.done ]; then + lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests + touch data/manifests/.baker-zh.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Generate tokens.txt" + if [ ! -e data/tokens.txt ]; then + python3 ./local/generate_tokens.py --tokens data/tokens.txt + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Generate raw cutset" + if [ ! -e data/manifests/baker_zh_cuts_raw.jsonl.gz ]; then + lhotse cut simple \ + -r ./data/manifests/baker_zh_recordings_all.jsonl.gz \ + -s ./data/manifests/baker_zh_supervisions_all.jsonl.gz \ + ./data/manifests/baker_zh_cuts_raw.jsonl.gz + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Convert text to tokens" + if [ ! -e data/manifests/baker_zh_cuts.jsonl.gz ]; then + python3 ./local/convert_text_to_tokens.py \ + --in-file ./data/manifests/baker_zh_cuts_raw.jsonl.gz \ + --out-file ./data/manifests/baker_zh_cuts.jsonl.gz + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate fbank (used by ./matcha)" + mkdir -p data/fbank + if [ ! -e data/fbank/.baker-zh.done ]; then + ./local/compute_fbank_baker_zh.py + touch data/fbank/.baker-zh.done + fi + + if [ ! -e data/fbank/.baker-zh-validated.done ]; then + log "Validating data/fbank for baker-zh (used by ./matcha)" + python3 ./local/validate_manifest.py \ + data/fbank/baker_zh_cuts.jsonl.gz + touch data/fbank/.baker-zh-validated.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Split the baker-zh cuts into train, valid and test sets (used by ./matcha)" + if [ ! 
-e data/fbank/.baker_zh_split.done ]; then
+    lhotse subset --last 600 \
+      data/fbank/baker_zh_cuts.jsonl.gz \
+      data/fbank/baker_zh_cuts_validtest.jsonl.gz
+    lhotse subset --first 100 \
+      data/fbank/baker_zh_cuts_validtest.jsonl.gz \
+      data/fbank/baker_zh_cuts_valid.jsonl.gz
+    lhotse subset --last 500 \
+      data/fbank/baker_zh_cuts_validtest.jsonl.gz \
+      data/fbank/baker_zh_cuts_test.jsonl.gz
+
+    rm data/fbank/baker_zh_cuts_validtest.jsonl.gz
+
+    n=$(( $(gunzip -c data/fbank/baker_zh_cuts.jsonl.gz | wc -l) - 600 ))
+
+    lhotse subset --first $n \
+      data/fbank/baker_zh_cuts.jsonl.gz \
+      data/fbank/baker_zh_cuts_train.jsonl.gz
+
+    touch data/fbank/.baker_zh_split.done
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Compute fbank mean and std (used by ./matcha)"
+  if [ ! -f ./data/fbank/cmvn.json ]; then
+    ./local/compute_fbank_statistics.py ./data/fbank/baker_zh_cuts_train.jsonl.gz ./data/fbank/cmvn.json
+  fi
+fi
diff --git a/egs/baker_zh/TTS/shared b/egs/baker_zh/TTS/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/baker_zh/TTS/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py
index 6512aa68b..76b7afcab 100755
--- a/egs/commonvoice/ASR/local/compile_hlg.py
+++ b/egs/commonvoice/ASR/local/compile_hlg.py
@@ -73,11 +73,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
-    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False))
 
     if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
         logging.info(f"Loading pre-compiled {lm}")
-        d = torch.load(f"{lang_dir}/lm/{lm}.pt")
+        d = torch.load(f"{lang_dir}/lm/{lm}.pt", weights_only=False)
         G = k2.Fsa.from_dict(d)
     else:
         logging.info(f"Loading {lm}.fst.txt")
diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py
index 76dacb5b2..2a17e91c6 100755
--- a/egs/commonvoice/ASR/local/compile_lg.py
+++ b/egs/commonvoice/ASR/local/compile_lg.py
@@ -68,11 +68,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
       An FSA representing LG.
""" lexicon = Lexicon(lang_dir) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) if Path(f"{lang_dir}/lm/{lm}.pt").is_file(): logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"{lang_dir}/lm/{lm}.pt") + d = torch.load(f"{lang_dir}/lm/{lm}.pt", weights_only=False) G = k2.Fsa.from_dict(d) else: logging.info(f"Loading {lm}.fst.txt") diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py index 52b2fbcab..00f6616a4 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/decode.py @@ -910,7 +910,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py index b6e2451e8..eee563e70 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py index 5e98084ec..7a859ff38 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -88,6 +88,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -825,7 +826,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1220,7 +1221,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py index 7ae4f1894..6dfb32728 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -767,7 +767,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py 
b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py index 976004eca..1b187da1a 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py @@ -90,6 +90,7 @@ from icefall.utils import ( filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -626,7 +627,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -895,7 +896,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1293,7 +1294,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py index 3fd14aa47..1a104442f 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py @@ -25,7 +25,7 @@ Usage: --exp-dir ./pruned_transducer_stateless7/exp It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`. +You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt", weights_only=False)`. (2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt ./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ @@ -35,7 +35,7 @@ You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`. --exp-dir ./pruned_transducer_stateless7/exp It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`. +You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt", weights_only=False)`. (3) use the original model with checkpoint exp_dir/epoch-xxx.pt ./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ @@ -45,7 +45,7 @@ You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`. --exp-dir ./pruned_transducer_stateless7/exp It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15.pt")`. +You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`. (4) use the original model with checkpoint exp_dir/checkpoint-iter.pt ./pruned_transducer_stateless7/generate_model_from_checkpoint.py \ @@ -55,7 +55,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`. --exp-dir ./pruned_transducer_stateless7/exp It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5.pt")`. 
+You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`. """ diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py index 67e1a8133..f1e9b6d43 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py @@ -81,7 +81,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -840,7 +846,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1237,7 +1243,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/zipformer/decode.py b/egs/commonvoice/ASR/zipformer/decode.py index 7fd6d0ccd..5e3cbaf92 100755 --- a/egs/commonvoice/ASR/zipformer/decode.py +++ b/egs/commonvoice/ASR/zipformer/decode.py @@ -987,7 +987,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/commonvoice/ASR/zipformer/decode_char.py b/egs/commonvoice/ASR/zipformer/decode_char.py index 1f8c9c7c6..8a814122d 100755 --- a/egs/commonvoice/ASR/zipformer/decode_char.py +++ b/egs/commonvoice/ASR/zipformer/decode_char.py @@ -756,7 +756,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py index 271014db0..c6940def5 100755 --- a/egs/commonvoice/ASR/zipformer/train.py +++ b/egs/commonvoice/ASR/zipformer/train.py @@ -97,6 +97,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -969,7 +970,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1365,7 +1366,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git 
a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py index 0aa7856cc..f44232c0e 100755 --- a/egs/commonvoice/ASR/zipformer/train_char.py +++ b/egs/commonvoice/ASR/zipformer/train_char.py @@ -97,6 +97,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -604,7 +605,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -784,7 +785,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py index f5a1d750d..8c8e7ab83 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -791,7 +791,7 @@ def main(): if params.decoding_graph: decoding_graph = k2.Fsa.from_dict( - torch.load(params.decoding_graph, map_location=device) + torch.load(params.decoding_graph, map_location=device, weights_only=False) ) elif "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": @@ -800,7 +800,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py index 66fbae378..3a7a05820 100644 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -239,7 +239,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 6a249dd3f..fa4f41483 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -561,7 +561,7 @@ def main(): decoding_graph = None if params.decoding_graph: decoding_graph = k2.Fsa.from_dict( - torch.load(params.decoding_graph, map_location=device) + torch.load(params.decoding_graph, map_location=device, weights_only=False) ) elif params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py index ef7ea9013..5862cd660 100755 --- 
a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -83,7 +83,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LOG_EPS = math.log(1e-10) @@ -838,7 +844,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1245,7 +1251,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/fluent_speech_commands/SLU/local/compile_hlg.py b/egs/fluent_speech_commands/SLU/local/compile_hlg.py index a7df8f966..803164d82 100755 --- a/egs/fluent_speech_commands/SLU/local/compile_hlg.py +++ b/egs/fluent_speech_commands/SLU/local/compile_hlg.py @@ -47,7 +47,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) logging.info("Loading G.fst.txt") with open(lang_dir / "G.fst.txt") as f: diff --git a/egs/fluent_speech_commands/SLU/local/prepare_lang.py b/egs/fluent_speech_commands/SLU/local/prepare_lang.py index 2a71dcf81..72b9bf1c3 100755 --- a/egs/fluent_speech_commands/SLU/local/prepare_lang.py +++ b/egs/fluent_speech_commands/SLU/local/prepare_lang.py @@ -14,7 +14,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. 
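Two changes recur throughout the hunks in this patch: torch.load calls gain weights_only=False, because newer PyTorch releases (2.6 and later) default to weights_only=True and then refuse to unpickle the k2 FSA dictionaries and full training checkpoints used here, and the deprecated torch.cuda.amp.autocast context manager is replaced by a torch_autocast helper imported from icefall.utils. Below is a minimal sketch of what such a wrapper can look like, given only for illustration; the actual definition lives in icefall/utils.py, is not part of this patch, and may differ.

# Hypothetical sketch of a torch_autocast helper; the real implementation in
# icefall/utils.py may differ.
from contextlib import contextmanager

import torch


@contextmanager
def torch_autocast(device_type: str = "cuda", **kwargs):
    """Dispatch to the non-deprecated autocast API when it is available."""
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        # Newer PyTorch exposes a device-agnostic torch.amp.autocast.
        with torch.amp.autocast(device_type=device_type, **kwargs):
            yield
    else:
        # Fall back to the older, CUDA-specific (now deprecated) API.
        with torch.cuda.amp.autocast(**kwargs):
            yield


# Loading pickled k2 FSAs or icefall checkpoints then looks like:
#   d = torch.load("L.pt", map_location="cpu", weights_only=False)
#   lexicon = k2.Fsa.from_dict(d)

With a wrapper of this shape, call sites such as `with torch_autocast(enabled=params.use_fp16):` stay identical across PyTorch versions, which is how the training scripts touched by this patch use it.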
diff --git a/egs/fluent_speech_commands/SLU/shared b/egs/fluent_speech_commands/SLU/shared index 9115c7e17..4cbd91a7e 120000 --- a/egs/fluent_speech_commands/SLU/shared +++ b/egs/fluent_speech_commands/SLU/shared @@ -1 +1 @@ -../../../icefall/shared/ +../../../icefall/shared \ No newline at end of file diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index d7035a1f8..47f35174f 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -589,7 +589,7 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) ) assert HLG.requires_grad is False @@ -628,7 +628,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) G = k2.Fsa.from_dict(d) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: diff --git a/egs/gigaspeech/ASR/local/compile_lg.py b/egs/gigaspeech/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/gigaspeech/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py index 9e0df0989..14353008c 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py @@ -32,13 +32,21 @@ torch.set_num_interop_threads(1) def compute_fbank_gigaspeech(): in_out_dir = Path("data/fbank") + # number of workers in dataloader num_workers = 20 # number of seconds in a batch batch_duration = 1000 - subsets = ("L", "M", "S", "XS", "DEV", "TEST") + subsets = ( + "DEV", + "TEST", + # "L", + # "M", + # "S", + # "XS", + ) device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 51cd59078..c1645f7cc 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -18,7 +18,7 @@ import argparse import logging -from datetime import datetime +import os from pathlib import Path import torch @@ -32,7 +32,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def get_parser(): +def get_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) @@ -71,17 +71,15 @@ def get_parser(): default=-1, help="Stop processing pieces until this number (exclusive).", ) - return parser + return parser.parse_args() def compute_fbank_gigaspeech_splits(args): num_splits = args.num_splits - output_dir = f"data/fbank/XL_split" + output_dir = "data/fbank/gigaspeech_XL_split" output_dir = Path(output_dir) assert output_dir.exists(), f"{output_dir} does not exist!" 
- num_digits = 8 # num_digits is fixed by lhotse split-lazy - start = args.start stop = args.stop if stop < start: @@ -95,6 +93,7 @@ def compute_fbank_gigaspeech_splits(args): extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") + num_digits = 8 # num_digits is fixed by lhotse split-lazy for i in range(start, stop): idx = f"{i}".zfill(num_digits) logging.info(f"Processing {idx}/{num_splits}") @@ -105,15 +104,22 @@ def compute_fbank_gigaspeech_splits(args): continue raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz" + if not raw_cuts_path.is_file(): + logging.info(f"{raw_cuts_path} does not exist - skipping it") + continue logging.info(f"Loading {raw_cuts_path}") cut_set = CutSet.from_file(raw_cuts_path) logging.info("Computing features") + filename = output_dir / f"gigaspeech_feats_XL_{idx}.lca" + if filename.exists(): + logging.info(f"Removing {filename}") + os.remove(str(filename)) cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, - storage_path=f"{output_dir}/gigaspeech_feats_{idx}", + storage_path=f"{output_dir}/gigaspeech_feats_XL_{idx}", num_workers=args.num_workers, batch_duration=args.batch_duration, overwrite=True, @@ -130,29 +136,10 @@ def compute_fbank_gigaspeech_splits(args): def main(): - now = datetime.now() - date_time = now.strftime("%Y-%m-%d-%H-%M-%S") - - log_filename = "log-compute_fbank_gigaspeech_splits" formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - log_filename = f"{log_filename}-{date_time}" - - logging.basicConfig( - filename=log_filename, - format=formatter, - level=logging.INFO, - filemode="w", - ) - - console = logging.StreamHandler() - console.setLevel(logging.INFO) - console.setFormatter(logging.Formatter(formatter)) - logging.getLogger("").addHandler(console) - - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() compute_fbank_gigaspeech_splits(args) diff --git a/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py deleted file mode 120000 index 2ce13fd69..000000000 --- a/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index a31685211..5bc881ab3 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -30,18 +30,6 @@ from icefall.utils import str2bool # https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--perturb-speed", - type=str2bool, - default=False, - help="Whether to use speed perturbation.", - ) - - return parser.parse_args() - - def normalize_text( utt: str, punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), @@ -57,7 +45,7 @@ def has_no_oov( return oov_pattern.search(sup.text) is None -def preprocess_giga_speech(args): +def preprocess_gigaspeech(): src_dir = Path("data/manifests") output_dir = Path("data/fbank") output_dir.mkdir(exist_ok=True) @@ -66,10 +54,10 @@ def preprocess_giga_speech(args): "DEV", "TEST", "XL", - "L", - "M", - "S", - "XS", + # "L", + # "M", + # "S", + # "XS", ) 
logging.info("Loading manifest (may take 4 minutes)") @@ -110,17 +98,7 @@ def preprocess_giga_speech(args): recordings=m["recordings"], supervisions=m["supervisions"], ) - # Run data augmentation that needs to be done in the - # time domain. - if partition not in ["DEV", "TEST"]: - if args.perturb_speed: - logging.info( - f"Speed perturb for {partition} with factors 0.9 and 1.1 " - "(Perturbing may take 8 minutes and saving may take 20 minutes)" - ) - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) + logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) @@ -129,8 +107,7 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - preprocess_giga_speech(args) + preprocess_gigaspeech() if __name__ == "__main__": diff --git a/egs/gigaspeech/ASR/local/validate_bpe_lexicon.py b/egs/gigaspeech/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/gigaspeech/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index 219197e13..ef6a667f9 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -6,12 +6,24 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail nj=15 -stage=0 -stop_stage=100 -# Split XL subset to a number of pieces (about 2000) -# This is to avoid OOM during feature extraction. -num_per_split=50 +# Run step 0 to step 8 by default +stage=0 +stop_stage=8 + +# Compute fbank features for a subset of splits from `start` (inclusive) to `stop` (exclusive) +start=0 +stop=-1 # -1 means until the end + +# Note: This script just prepares the minimal requirements needed by a +# transducer training with bpe units. +# +# If you want to use ngram, please continue running prepare_lm.sh after +# you succeed in running this script. +# +# This script also contains the steps to generate phone based units, but they +# will not run automatically, you can generate the phone based units by +# bash prepare.sh --stage 9 --stop-stage 9 # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -34,9 +46,10 @@ num_per_split=50 # This directory contains the following directories downloaded from # http://www.openslr.org/17/ # -# - music -# - noise -# - speech +# - music +# - noise +# - speech + dl_dir=$PWD/download . shared/parse_options.sh || exit 1 @@ -45,6 +58,9 @@ dl_dir=$PWD/download # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( + # 5000 + # 2000 + # 1000 500 ) @@ -58,10 +74,12 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +log "Running prepare.sh" + log "dl_dir: $dl_dir" if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "stage -1: Download LM" + log "Stage -1: Download LM" # We assume that you have installed the git-lfs, if not, you could install it # using: `sudo apt-get install git-lfs && git-lfs install` [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm @@ -78,7 +96,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # If you have pre-downloaded it to /path/to/GigaSpeech, # you can create a symlink # - # ln -sfv /path/to/GigaSpeech $dl_dir/GigaSpeech + # ln -svf /path/to/GigaSpeech $dl_dir/GigaSpeech # if [ ! 
-d $dl_dir/GigaSpeech/audio ] && [ ! -f $dl_dir/GigaSpeech.json ]; then # Check credentials. @@ -88,32 +106,37 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then echo " and save it to $dl_dir/password." exit 1; fi + PASSWORD=`cat $dl_dir/password 2>/dev/null` if [ -z "$PASSWORD" ]; then echo "$0: Error, $dl_dir/password is empty." exit 1; fi + PASSWORD_MD5=`echo $PASSWORD | md5sum | cut -d ' ' -f 1` if [[ $PASSWORD_MD5 != "dfbf0cde1a3ce23749d8d81e492741b8" ]]; then echo "$0: Error, invalid $dl_dir/password." exit 1; fi + # Download XL, DEV and TEST sets by default. - lhotse download gigaspeech --subset XL \ - --subset L \ - --subset M \ - --subset S \ - --subset XS \ + # Support hosts: + # 1. oss + # 2. tsinghua + # 3. speechocean + # 4. magicdata + lhotse download gigaspeech \ + --host magicdata \ --subset DEV \ --subset TEST \ - --host tsinghua \ + --subset XL \ $dl_dir/password $dl_dir/GigaSpeech fi # If you have pre-downloaded it to /path/to/musan, # you can create a symlink # - # ln -sfv /path/to/musan $dl_dir/ + # ln -svf /path/to/musan $dl_dir/ # if [ ! -d $dl_dir/musan ]; then lhotse download musan $dl_dir @@ -125,11 +148,8 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # We assume that you have downloaded the GigaSpeech corpus # to $dl_dir/GigaSpeech mkdir -p data/manifests - lhotse prepare gigaspeech --subset XL \ - --subset L \ - --subset M \ - --subset S \ - --subset XS \ + lhotse prepare gigaspeech \ + --subset XL \ --subset DEV \ --subset TEST \ -j $nj \ @@ -147,19 +167,20 @@ fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "State 3: Preprocess GigaSpeech manifest" if [ ! -f data/fbank/.preprocess_complete ]; then - python3 ./local/preprocess_gigaspeech.py - touch data/fbank/.preprocess_complete + python3 ./local/preprocess_gigaspeech.py + touch data/fbank/.preprocess_complete fi fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech." + log "Stage 4: Compute features for DEV, TEST, L, M, S, and XS subsets of GigaSpeech." python3 ./local/compute_fbank_gigaspeech.py fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Split XL subset into pieces (may take 30 minutes)" - split_dir=data/fbank/XL_split + log "Stage 5: Split XL subset into pieces (may take 5 minutes)" + num_per_split=50 + split_dir=data/fbank/gigaspeech_XL_split if [ ! -f $split_dir/.split_completed ]; then lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $num_per_split touch $split_dir/.split_completed @@ -168,82 +189,63 @@ fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Compute features for XL" - num_splits=$(find data/fbank/XL_split -name "gigaspeech_cuts_XL_raw.*.jsonl.gz" | wc -l) + split_dir=data/fbank/gigaspeech_XL_split + num_splits=$(find $split_dir -name "gigaspeech_cuts_XL_raw.*.jsonl.gz" | wc -l) python3 ./local/compute_fbank_gigaspeech_splits.py \ --num-workers 20 \ --batch-duration 600 \ - --num-splits $num_splits + --num-splits $num_splits \ + --start $start \ + --stop $stop fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Combine features for XL (may take 3 hours)" - if [ ! 
-f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then - pieces=$(find data/fbank/XL_split -name "gigaspeech_cuts_XL.*.jsonl.gz") - lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz - fi -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Compute fbank for musan" + log "Stage 7: Compute fbank for musan" mkdir -p data/fbank ./local/compute_fbank_musan.py fi -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Prepare transcript_words.txt and words.txt" - lang_dir=data/lang_phone - mkdir -p $lang_dir - if [ ! -f $lang_dir/transcript_words.txt ]; then - gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \ - | jq '.text' \ - | sed 's/"//g' \ - > $lang_dir/transcript_words.txt +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare BPE based lang" + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir - # Delete utterances with garbage meta tags - garbage_utterance_tags=" " - for tag in $garbage_utterance_tags; do - sed -i "/${tag}/d" $lang_dir/transcript_words.txt - done + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \ + | jq '.text' \ + | sed 's/"//g' \ + > $lang_dir/transcript_words.txt - # Delete punctuations in utterances - punctuation_tags=" " - for tag in $punctuation_tags; do - sed -i "s/${tag}//g" $lang_dir/transcript_words.txt - done + # Delete utterances with garbage meta tags + garbage_utterance_tags=" " + for tag in $garbage_utterance_tags; do + sed -i "/${tag}/d" $lang_dir/transcript_words.txt + done - # Ensure space only appears once - sed -i 's/\t/ /g' $lang_dir/transcript_words.txt - sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt - fi + # Delete punctuations in utterances + punctuation_tags=" " + for tag in $punctuation_tags; do + sed -i "s/${tag}//g" $lang_dir/transcript_words.txt + done - cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' > $lang_dir/words.txt - (echo '!SIL'; echo ''; echo ''; ) | - cat - $lang_dir/words.txt | sort | uniq | awk ' - BEGIN { - print " 0"; - } - { - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - printf("%s %d\n", $1, NR); - } - END { - printf("#0 %d\n", NR+1); - printf(" %d\n", NR+2); - printf(" %d\n", NR+3); - }' > $lang_dir/words || exit 1; - mv $lang_dir/words $lang_dir/words.txt + # Ensure space only appears once + sed -i 's/\t/ /g' $lang_dir/transcript_words.txt + sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + done fi -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Prepare phone based lang" +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Prepare phone based lang" lang_dir=data/lang_phone mkdir -p $lang_dir @@ -255,93 +257,3 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then ./local/prepare_lang.py --lang-dir $lang_dir fi fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. 
- cp data/lang_phone/{words.txt,transcript_words.txt} $lang_dir - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - fi - done -fi - -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Prepare bigram P" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! -f $lang_dir/transcript_tokens.txt ]; then - ./local/convert_transcript_words_to_tokens.py \ - --lexicon $lang_dir/lexicon.txt \ - --transcript $lang_dir/transcript_words.txt \ - --oov "" \ - > $lang_dir/transcript_tokens.txt - fi - - if [ ! -f $lang_dir/P.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 2 \ - -text $lang_dir/transcript_tokens.txt \ - -lm $lang_dir/P.arpa - fi - - if [ ! -f $lang_dir/P.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=2 \ - $lang_dir/P.arpa > $lang_dir/P.fst.txt - fi - done -fi - -if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then - log "Stage 13: Prepare G" - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - - mkdir -p data/lm - - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $dl_dir/lm/3gram_pruned_1e7.arpa > data/lm/G_3_gram.fst.txt - fi - - if [ ! -f data/lm/G_4_gram.fst.txt ]; then - # It is used for LM rescoring - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - $dl_dir/lm/4gram.arpa > data/lm/G_4_gram.fst.txt - fi -fi - -if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then - log "Stage 14: Compile HLG" - ./local/compile_hlg.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir - done -fi diff --git a/egs/gigaspeech/ASR/prepare_lm.sh b/egs/gigaspeech/ASR/prepare_lm.sh new file mode 100755 index 000000000..3fcf899a3 --- /dev/null +++ b/egs/gigaspeech/ASR/prepare_lm.sh @@ -0,0 +1,98 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +# This script generates Ngram LM and related files needed by decoding. + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/lm +# This directory contains the language model downloaded from +# https://huggingface.co/wgb14/gigaspeech_lm +# +# - 3gram_pruned_1e7.arpa.gz +# - 4gram.arpa.gz +# - lexicon.txt + +. prepare.sh --stage -1 --stop-stage 9 || exit 1 + +stage=0 +stop_stage=100 + +. shared/parse_options.sh || exit 1 + +log "Running prepare_lm.sh" + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare BPE based lexicon" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! 
-f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare word-level G" + # We assume you have installed kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + $dl_dir/lm/3gram_pruned_1e7.arpa > data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + $dl_dir/lm/4gram.arpa > data/lm/G_4_gram.fst.txt + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compile LG" + # It is used for RNN-T fast_beam_search decoding + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 40339365c..c06e8f461 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -101,7 +101,7 @@ class GigaSpeechAsrDataModule: group.add_argument( "--num-buckets", type=int, - default=30, + default=15, help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) @@ -294,8 +294,7 @@ class GigaSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, + buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index a7772b62f..56371e59a 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,7 +77,13 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -675,7 +681,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -944,7 +950,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the
derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index 0501461cd..93a41b27a 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -105,7 +105,7 @@ class GigaSpeechAsrDataModule: group.add_argument( "--num-buckets", type=int, - default=100, + default=15, help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) @@ -219,6 +219,8 @@ class GigaSpeechAsrDataModule: self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, + world_size: Optional[int] = None, + rank: Optional[int] = None, ) -> DataLoader: """ Args: @@ -311,9 +313,10 @@ class GigaSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, + buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, ) else: logging.info("Using SimpleCutSampler.") @@ -321,6 +324,8 @@ class GigaSpeechAsrDataModule: cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, + world_size=world_size, + rank=rank, ) logging.info("About to create train dataloader") @@ -344,7 +349,12 @@ class GigaSpeechAsrDataModule: return train_dl - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: Optional[int] = None, + rank: Optional[int] = None, + ) -> DataLoader: transforms = [] if self.args.concatenate_cuts: transforms = [ @@ -369,8 +379,10 @@ class GigaSpeechAsrDataModule: cuts_valid, max_duration=self.args.max_duration, num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, + buffer_size=self.args.num_buckets * 5000, shuffle=False, + world_size=world_size, + rank=rank, ) logging.info("About to create dev dataloader") valid_dl = DataLoader( @@ -410,7 +422,7 @@ class GigaSpeechAsrDataModule: logging.info(f"About to get train {self.args.subset} cuts") if self.args.subset == "XL": filenames = glob.glob( - f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz" + f"{self.args.manifest_dir}/gigaspeech_XL_split/gigaspeech_cuts_XL.*.jsonl.gz" ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py index 651f20cb6..c28abf020 100755 --- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py +++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py @@ -668,7 +668,7 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) ) assert HLG.requires_grad is False @@ -707,7 +707,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load(params.lm_dir / 
"G_4_gram.pt", map_location=device, weights_only=False) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/gigaspeech/ASR/zipformer/decode.py b/egs/gigaspeech/ASR/zipformer/decode.py index 3a0c71484..cbd54ad9e 100755 --- a/egs/gigaspeech/ASR/zipformer/decode.py +++ b/egs/gigaspeech/ASR/zipformer/decode.py @@ -1000,7 +1000,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index 4c122effe..d586fc26a 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -97,6 +97,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -958,7 +959,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1201,12 +1202,19 @@ def run(rank, world_size, args): sampler_state_dict = None train_dl = gigaspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict + train_cuts, + sampler_state_dict=sampler_state_dict, + world_size=world_size, + rank=rank, ) valid_cuts = gigaspeech.dev_cuts() valid_cuts = valid_cuts.filter(remove_short_utt) - valid_dl = gigaspeech.valid_dataloaders(valid_cuts) + valid_dl = gigaspeech.valid_dataloaders( + valid_cuts, + world_size=world_size, + rank=rank, + ) if not params.print_diagnostics and params.scan_for_oom_batches: scan_pessimistic_batches_for_oom( @@ -1317,7 +1325,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py deleted file mode 100644 index ccc602404..000000000 --- a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py +++ /dev/null @@ -1,477 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2024 Xiaomi Corporation (Author: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import glob -import inspect -import logging -import re -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import lhotse -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class GigaSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. 
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - # GigaSpeech specific arguments - group.add_argument( - "--subset", - type=str, - default="XL", - help="Select the GigaSpeech subset (XS|S|M|L|XL)", - ) - group.add_argument( - "--small-dev", - type=str2bool, - default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. 
- num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - input_strategy=eval(self.args.input_strategy)(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info(f"About to get train {self.args.subset} cuts") - if self.args.subset == "XL": - filenames = glob.glob( - f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz" - ) - pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) - idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) - sorted_filenames = [f[1] for f in idx_filenames] - logging.info( - f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode" - ) - - cuts_train = lhotse.combine( - lhotse.load_manifest_lazy(p) for p in sorted_filenames - ) - else: - path = ( - self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" - ) - cuts_train = CutSet.from_jsonl_lazy(path) - return cuts_train - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" - ) - if self.args.small_dev: - return cuts_valid.subset(first=1000) - else: - return cuts_valid - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" - ) - - @lru_cache() - def fsc_train_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz" - ) - - @lru_cache() - def fsc_valid_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands valid 
cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz" - ) - - @lru_cache() - def fsc_test_small_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands small test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz" - ) - - @lru_cache() - def fsc_test_large_cuts(self) -> CutSet: - logging.info("About to get fluent speech commands large test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz" - ) diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py new file mode 120000 index 000000000..75bc3d45a --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../../ASR/zipformer/asr_datamodule.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/decode-asr.py b/egs/gigaspeech/KWS/zipformer/decode-asr.py index 149b8bed0..9d1c36466 100755 --- a/egs/gigaspeech/KWS/zipformer/decode-asr.py +++ b/egs/gigaspeech/KWS/zipformer/decode-asr.py @@ -1001,7 +1001,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py index a7ba56127..91ed7c093 100755 --- a/egs/gigaspeech/KWS/zipformer/finetune.py +++ b/egs/gigaspeech/KWS/zipformer/finetune.py @@ -183,7 +183,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index 39d8fc6cd..2d88b6e55 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -97,6 +97,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -961,7 +962,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1320,7 +1321,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode.py index 0f3f1c1ab..c82b910bb 100755 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -938,7 +938,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= 
params.ngram_lm_scale else: diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py index bf50bf5ea..63a38a4cc 100755 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -77,7 +77,13 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -805,7 +811,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1196,7 +1202,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ksponspeech/ASR/zipformer/ctc_decode.py b/egs/ksponspeech/ASR/zipformer/ctc_decode.py index 30bf1610b..10239db5e 100755 --- a/egs/ksponspeech/ASR/zipformer/ctc_decode.py +++ b/egs/ksponspeech/ASR/zipformer/ctc_decode.py @@ -666,7 +666,7 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) ) assert HLG.requires_grad is False @@ -705,7 +705,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/ksponspeech/ASR/zipformer/decode.py b/egs/ksponspeech/ASR/zipformer/decode.py index 5c21abb79..ba0383010 100755 --- a/egs/ksponspeech/ASR/zipformer/decode.py +++ b/egs/ksponspeech/ASR/zipformer/decode.py @@ -989,7 +989,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py index 485ea69c9..406749f22 100755 --- a/egs/ksponspeech/ASR/zipformer/train.py +++ b/egs/ksponspeech/ASR/zipformer/train.py @@ -92,6 +92,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -942,7 +943,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1333,7 +1334,7 @@ def 
scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libricss/SURT/dprnn_zipformer/pretrained.py b/egs/libricss/SURT/dprnn_zipformer/pretrained.py index 5f9468957..73468417a 100755 --- a/egs/libricss/SURT/dprnn_zipformer/pretrained.py +++ b/egs/libricss/SURT/dprnn_zipformer/pretrained.py @@ -177,7 +177,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/libricss/SURT/dprnn_zipformer/train.py b/egs/libricss/SURT/dprnn_zipformer/train.py index 148cafd4b..186d4f6fb 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train.py +++ b/egs/libricss/SURT/dprnn_zipformer/train.py @@ -1286,7 +1286,7 @@ def run(rank, world_size, args): logging.info( f"Initializing model with checkpoint from {params.model_init_ckpt}" ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device) + init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False) model.load_state_dict(init_ckpt["model"], strict=False) if world_size > 1: diff --git a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py index 8c37430ec..4d1f3cf02 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py @@ -1175,7 +1175,7 @@ def run(rank, world_size, args): logging.info( f"Initializing model with checkpoint from {params.model_init_ckpt}" ) - init_ckpt = torch.load(params.model_init_ckpt, map_location=device) + init_ckpt = torch.load(params.model_init_ckpt, map_location=device, weights_only=False) model.load_state_dict(init_ckpt["model"], strict=True) if world_size > 1: diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh index 366a1459f..110d7b7ba 100755 --- a/egs/libriheavy/ASR/prepare.sh +++ b/egs/libriheavy/ASR/prepare.sh @@ -245,7 +245,6 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then done fi - if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then log "Stage 10: Train BPE model for unnormalized text" if [ ! 
-f data/punc_texts ]; then diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py index 458109a3f..763bb8b51 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py @@ -252,7 +252,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librilight/SSL/zipformer/decode.py b/egs/librilight/SSL/zipformer/decode.py index 95643c5e1..88b67600b 100644 --- a/egs/librilight/SSL/zipformer/decode.py +++ b/egs/librilight/SSL/zipformer/decode.py @@ -960,7 +960,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librilight/SSL/zipformer/finetune.py b/egs/librilight/SSL/zipformer/finetune.py index 50dbd5f2d..793725614 100644 --- a/egs/librilight/SSL/zipformer/finetune.py +++ b/egs/librilight/SSL/zipformer/finetune.py @@ -750,7 +750,7 @@ def _to_int_tuple(s: str): def get_encoder_model(params: AttributeDict) -> nn.Module: if hasattr(params, "pretrained_dir"): logging.info(f"Loading {params.pretrained_dir}") - pretrained = torch.load(params.pretrained_dir) + pretrained = torch.load(params.pretrained_dir, weights_only=False) encoder = HubertModel(params) encoder.load_state_dict(pretrained["model"]) else: diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index a1cfe6e75..ea793ce2f 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -32,7 +32,7 @@ class Conformer(Transformer): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers num_decoder_layers (int): number of decoder layers dropout (float): dropout rate @@ -902,7 +902,7 @@ class Swish(torch.nn.Module): """Construct an Swish object.""" def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" + """Return Swish activation function.""" return x * torch.sigmoid(x) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 7e0bf5b7b..fc866f83b 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -667,7 +667,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -707,7 +709,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, 
weights_only=False + ) G = k2.Fsa.from_dict(d) if params.method in [ diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 38b60fcb9..5b3a021ad 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -271,7 +271,7 @@ def main(): use_feat_batchnorm=params.use_feat_batchnorm, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -351,7 +351,9 @@ def main(): "attention-decoder", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -362,7 +364,9 @@ def main(): "attention-decoder", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index 09f1eb000..02ea80a46 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -42,7 +42,7 @@ class Conformer(Transformer): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers num_decoder_layers (int): number of decoder layers dropout (float): dropout rate diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 0b271a51c..349e8f02d 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -774,7 +774,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -814,7 +816,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.method in [ diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index c4a13b101..14c132ada 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -65,7 +65,6 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -84,9 +83,11 @@ from icefall.lexicon import Lexicon 
from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -420,7 +421,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -629,7 +630,7 @@ def train_one_epoch( scheduler: LRSchedulerType, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -676,7 +677,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -965,7 +966,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index e6327bb5e..cf58fd18d 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -868,7 +868,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -907,7 +909,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py index 19b26361e..f8e3fa43b 100755 --- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py @@ -334,7 +334,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -345,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = 
k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py index a0cdfcf03..e528b2cb8 100755 --- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py @@ -290,7 +290,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -386,7 +386,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -397,7 +399,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py index a2f1125ca..64e77f421 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -76,7 +76,6 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -95,9 +94,11 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -493,7 +494,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -694,7 +695,7 @@ def train_one_epoch( graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -743,7 +744,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1004,7 +1005,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1073,7 +1074,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 53e48eb13..cffe3df28 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -33,7 +33,7 @@ class Conformer(Transformer): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers num_decoder_layers (int): number of decoder layers dropout (float): dropout rate diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index 74f6e73fa..01fcf0685 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -574,7 +574,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + torch.load( + f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False + ) ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -609,7 +611,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False + ) G = k2.Fsa.from_dict(d).to(device) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index ca21bd6bf..fc33f9512 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -93,7 +92,14 @@ from icefall.checkpoint import 
( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -560,7 +566,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -727,7 +733,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -772,7 +778,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1002,7 +1008,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1071,7 +1077,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 23ddb6bec..b00cc6cc6 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -93,7 +92,14 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -560,7 +566,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -727,7 +733,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -772,7 +778,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1001,7 +1007,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1072,7 +1078,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index d19d50ae6..ec39d5b36 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -72,11 +72,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) if Path(f"data/lm/{lm}.pt").is_file(): logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"data/lm/{lm}.pt") + d = torch.load(f"data/lm/{lm}.pt", weights_only=False) G = k2.Fsa.from_dict(d) else: logging.info(f"Loading {lm}.fst.txt") diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 709b14070..bd25cfa29 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -66,11 +66,11 @@ def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: An FSA representing LG. """ lexicon = Lexicon(lang_dir) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) if Path(f"data/lm/{lm}.pt").is_file(): logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"data/lm/{lm}.pt") + d = torch.load(f"data/lm/{lm}.pt", weights_only=False) G = k2.Fsa.from_dict(d) else: logging.info(f"Loading {lm}.fst.txt") diff --git a/egs/librispeech/ASR/local/prepare_lang.py b/egs/librispeech/ASR/local/prepare_lang.py index d913756a1..82785ad6e 100755 --- a/egs/librispeech/ASR/local/prepare_lang.py +++ b/egs/librispeech/ASR/local/prepare_lang.py @@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. 
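Background on the recurring `weights_only=False` additions in this patch: PyTorch 2.6 changed the default of `torch.load` to `weights_only=True`, which restricts unpickling to plain tensors and simple containers, so files that carry more than that, such as the k2 graphs (L.pt, L_disambig.pt, HLG.pt, LG.pt, G_*.pt) and the training checkpoints loaded here, can fail to load. Passing `weights_only=False` restores the old behaviour for these trusted, locally generated files. A minimal sketch of the loading pattern (the path below is a placeholder, not taken from the patch):

    import k2
    import torch

    # Trusted, locally generated graph; opt back into full unpickling.
    d = torch.load("data/lang_bpe_500/HLG.pt", map_location="cpu", weights_only=False)
    HLG = k2.Fsa.from_dict(d)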
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 856c9d945..8c75eb871 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -750,7 +750,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index e7bad7ed8..9f148b348 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -156,7 +156,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -192,7 +192,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 42c3a5d7f..f29d1d9db 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -238,7 +238,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index feb81d500..e23da3b56 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -66,7 +66,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -82,9 +81,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -521,7 +522,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: 
"""Save model, optimizer, scheduler and training stats to file. @@ -717,7 +718,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -763,7 +764,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1023,7 +1024,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1092,7 +1093,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 1a724830b..cfbbb334c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -935,7 +935,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index 4957d14b1..5aafe10af 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index dcff088e2..888f9931e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -241,7 +241,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 4fc4fa7f8..1b31b5485 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -74,7 +74,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -90,9 +89,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -560,7 +561,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -772,7 +773,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -848,7 +849,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1176,7 +1177,7 @@ def run(rank, world_size, args): else: logging.info("Skip scan_pessimistic_batches_for_oom") - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1247,7 +1248,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index a2b4f9e1a..e25b79e2e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -815,7 +815,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index e39637bd8..619e783b0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -239,7 +239,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 2c1cef3a3..e169b499f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -66,7 +66,6 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -82,9 +81,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -551,7 +552,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = 
None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -747,7 +748,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -793,7 +794,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1067,7 +1068,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1136,7 +1137,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 40dc3260d..cf3dc9adb 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -10,11 +10,11 @@ nj=15 stage=0 stop_stage=5 -# Note: This script just prepare the minimal requirements that needed by a +# Note: This script just prepares the minimal requirements needed by a # transducer training with bpe units. # # If you want to use ngram or nnlm, please continue running prepare_lm.sh after -# you succeed running this script. +# you succeed in running this script. # # This script also contains the steps to generate phone based units, but they # will not run automatically, you can generate the phone based units by diff --git a/egs/librispeech/ASR/prepare_lm.sh b/egs/librispeech/ASR/prepare_lm.sh index 1792395d8..55a6e021c 100755 --- a/egs/librispeech/ASR/prepare_lm.sh +++ b/egs/librispeech/ASR/prepare_lm.sh @@ -5,7 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -# This script generate Ngram LM / NNLM and related files that needed by decoding. +# This script generates Ngram LM / NNLM and related files needed by decoding. # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index de367c234..69cc59756 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -42,7 +42,7 @@ class Conformer(EncoderInterface): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate layer_dropout (float): layer-dropout rate.
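Background on the `torch_autocast` and `create_grad_scaler` helpers that replace `torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler` throughout these training scripts: the `torch.cuda.amp` entry points are deprecated in recent PyTorch releases in favour of the device-agnostic `torch.amp` API. The helpers themselves are not shown in this patch, so the sketch below is only an assumption of what such thin wrappers in icefall/utils.py might look like:

    import torch

    def torch_autocast(device_type: str = "cuda", **kwargs):
        # Assumed wrapper: forwards to the non-deprecated torch.amp.autocast.
        return torch.amp.autocast(device_type=device_type, **kwargs)

    def create_grad_scaler(device: str = "cuda", **kwargs):
        # Assumed wrapper: torch.amp.GradScaler supersedes torch.cuda.amp.GradScaler.
        return torch.amp.GradScaler(device, **kwargs)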
diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index ca8c28af1..3b6ce9b89 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -21,7 +21,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -141,7 +141,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -176,7 +176,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 5b595c76c..5850555cd 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -10,9 +10,11 @@ from typing import Optional, Tuple import torch from scaling import ScaledLinear from torch import Tensor, nn -from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd +from torch.cuda.amp import custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined +from icefall.utils import create_grad_scaler, torch_autocast + # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. 
@@ -330,14 +332,14 @@ def _test_knowledge_base_lookup_autocast(): optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) - scaler = GradScaler(enabled=True) + scaler = create_grad_scaler(enabled=True) start = timeit.default_timer() for epoch in range(150): for n, (x, y) in enumerate(train_pairs): y_out = m(x) - with torch.cuda.amp.autocast(enabled=True): + with torch_autocast(enabled=True): loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 931341cc4..0611fd8cb 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -76,7 +75,14 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -453,7 +459,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -608,7 +614,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -650,7 +656,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -868,7 +874,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -937,7 +943,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam.
- with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 2b872f1d5..2af8f3f4c 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -55,7 +55,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from noam import Noam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -68,7 +67,14 @@ from icefall.checkpoint import ( ) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) def add_model_arguments(parser: argparse.ArgumentParser): @@ -496,7 +502,7 @@ def save_checkpoint( model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, and training stats to file. @@ -650,7 +656,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -693,7 +699,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -939,7 +945,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1004,7 +1010,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 3c4500087..6d1da7440 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -741,7 +741,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index e06404619..e1b9779a0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -264,7 +264,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 66c84b2a9..f1d16749c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1347,7 +1347,10 @@ def modified_beam_search( ( context_score, new_context_state, - ) = context_graph.forward_one_step(hyp.context_state, new_token) + _, + ) = context_graph.forward_one_step( + hyp.context_state, new_token, strict_mode=False + ) new_log_prob = topk_log_probs[k] + context_score @@ -2853,7 +2856,10 @@ def modified_beam_search_LODR( ( context_score, new_context_state, - ) = context_graph.forward_one_step(hyp.context_state, new_token) + _, + ) = context_graph.forward_one_step( + hyp.context_state, new_token, strict_mode=False + ) ys.append(new_token) state_cost = hyp.state_cost.forward_one_step(new_token) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index ab46e233b..85e61ebab 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -42,7 +42,7 @@ class Conformer(EncoderInterface): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate layer_dropout (float): layer-dropout rate. 
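Background on the `ContextGraph.forward_one_step` change in pruned_transducer_stateless2/beam_search.py above: the call now passes `strict_mode=False` and the method returns an extra value that the beam search discards. Sketch of the updated call pattern, with names taken directly from the hunk (hyp, new_token, topk_log_probs, k and context_graph come from the surrounding loop):

    context_score, new_context_state, _ = context_graph.forward_one_step(
        hyp.context_state, new_token, strict_mode=False
    )
    new_log_prob = topk_log_probs[k] + context_score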
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index c57514193..5a4a74ebb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -754,7 +754,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 272d06c37..6a69332aa 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -157,7 +157,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -193,7 +193,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 6923f4d40..e6ddcab25 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -265,7 +265,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 6c19f2cb0..ce6c89614 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -78,7 +78,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -91,9 +90,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -523,7 +524,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + 
scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -716,7 +717,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -759,7 +760,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1000,7 +1001,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1067,7 +1068,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 7c62bfa58..18a3792b0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -921,7 +921,7 @@ def load_ngram_LM( if pt_file.is_file(): logging.info(f"Loading pre-compiled {pt_file}") - d = torch.load(pt_file, map_location=device) + d = torch.load(pt_file, map_location=device, weights_only=False) G = k2.Fsa.from_dict(d) G = k2.add_epsilon_self_loops(G) G = k2.arc_sort(G) @@ -1101,7 +1101,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale elif params.decoding_method in [ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index d45f6dadc..fbc4db921 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 05e6a6fba..19143fb5d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -274,7 +274,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index fdafa5a87..50670d1b2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -74,7 +74,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -87,9 +86,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -546,7 +547,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
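The hunks above repeatedly replace torch.cuda.amp.autocast(...) with torch_autocast(...) and torch.cuda.amp.GradScaler(...) with create_grad_scaler(...), both imported from icefall.utils. Those helpers are not part of this diff; the following is only a minimal sketch, under the assumption that they wrap the newer device-agnostic torch.amp API and fall back to the deprecated torch.cuda.amp one on older PyTorch versions.

# Hypothetical sketch of the icefall.utils helpers imported above; the real
# implementations live in icefall/utils.py and may differ in detail.
import torch


def torch_autocast(enabled: bool = True, **kwargs):
    # Newer PyTorch deprecates torch.cuda.amp.autocast(...) in favour of the
    # device-agnostic torch.amp.autocast("cuda", ...).
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        return torch.amp.autocast(device_type="cuda", enabled=enabled, **kwargs)
    return torch.cuda.amp.autocast(enabled=enabled, **kwargs)


def create_grad_scaler(enabled: bool = True, **kwargs):
    # Likewise, torch.cuda.amp.GradScaler is deprecated in favour of
    # torch.amp.GradScaler("cuda", ...); kwargs such as init_scale pass through.
    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler("cuda", enabled=enabled, **kwargs)
    return torch.cuda.amp.GradScaler(enabled=enabled, **kwargs)

Either branch returns the kind of object the call sites expect: a context manager for autocast, and a gradient scaler exposing scale(), step(), update() and state_dict().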
@@ -755,7 +756,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -827,7 +828,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1126,7 +1127,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1195,7 +1196,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 5195a4ef6..925c01c7b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -913,7 +913,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 875b03f7f..c35f52309 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -96,9 +95,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -548,7 +549,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -744,7 +745,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -789,7 +790,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1047,7 +1048,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1116,7 +1117,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 8bbceec61..968ea4150 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -42,7 +42,7 @@ class Conformer(EncoderInterface): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate layer_dropout (float): layer-dropout rate. 
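The torch.load calls touched by this patch, including the decode.py change directly below, all gain weights_only=False. This is presumably needed because PyTorch 2.6 flipped the default of weights_only to True, which refuses to unpickle non-tensor objects such as k2.Fsa dictionaries and full training checkpoints. A minimal illustration of the intended usage follows; the checkpoint path is hypothetical.

# Illustrative only; "exp/epoch-30.pt" is a made-up path.
import torch

# Since PyTorch 2.6, torch.load defaults to weights_only=True and rejects
# arbitrary pickled Python objects (k2 FSA dicts, optimizer state with custom
# classes, ...).  For trusted files, pass weights_only=False explicitly.
checkpoint = torch.load("exp/epoch-30.pt", map_location="cpu", weights_only=False)
model_state = checkpoint["model"]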
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 7a3e63218..404d7a3d3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -972,7 +972,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index a9ce75a7b..9e2669379 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -238,7 +238,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 66dc5f991..6f9f92623 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -68,7 +68,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -84,9 +83,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -571,7 +572,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -768,7 +769,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -814,7 +815,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1078,7 +1079,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1147,7 +1148,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 0667e7f61..8c1529500 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -42,7 +42,7 @@ class Conformer(EncoderInterface): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate layer_dropout (float): layer-dropout rate. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index daadb70c9..a5d2457f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -185,7 +185,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -220,7 +220,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 8f033cb9a..35ee74f15 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -80,7 +80,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -96,9 +95,11 @@ from icefall.env import get_env_info from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, display_and_save_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -519,7 +520,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -736,7 +737,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -781,7 +782,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1039,7 +1040,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = create_grad_scaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1108,7 +1109,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 3bca7db2c..4f3fbaa81 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -348,7 +348,9 @@ class CodebookIndexExtractor: num_codebooks=self.params.num_codebooks, codebook_size=256, ) - quantizer.load_state_dict(torch.load(self.quantizer_file_path)) + quantizer.load_state_dict( + torch.load(self.quantizer_file_path, weights_only=False) + ) quantizer.to(self.params.device) return quantizer diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py index 27ef0a244..949a497ce 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py @@ -289,7 +289,7 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index eb8841cc4..048de7bb9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -910,7 +910,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py index 7095c3cc8..da1bf17fc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py @@ -813,7 +813,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index e7546ec45..d3d996b4a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -85,9 +84,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, filter_uneven_sized_batch, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -635,7 +636,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -678,7 +679,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
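Note that, with the module-level import of GradScaler from torch.cuda.amp removed, the annotations above and below switch to the string form Optional["GradScaler"] or "GradScaler". String annotations are never evaluated at runtime, so the name does not need to be importable when the script runs; a type checker could still resolve it through a TYPE_CHECKING-guarded import, as in this sketch (the guarded import is an assumption, not part of the patch).

# Sketch of the forward-reference pattern used by the annotations in this diff.
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Imported only for static type checkers; never executed at runtime.
    from torch.amp import GradScaler


def save_checkpoint_sketch(scaler: Optional["GradScaler"] = None) -> None:
    # The quoted annotation keeps the runtime free of any GradScaler import.
    ...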
@@ -857,7 +858,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -903,7 +904,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1219,7 +1220,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1319,7 +1320,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index add0e6a18..ed990b689 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -23,7 +23,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import penalize_abs_values_gt -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -150,7 +150,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 4bf11ac24..fabda3aaa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 30a737061..5a317083c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -28,6 +28,8 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Embedding as ScaledEmbedding +from icefall.utils import torch_autocast + class ActivationBalancerFunction(torch.autograd.Function): @staticmethod @@ -289,7 +291,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -669,7 +671,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): def backward(ctx, x_grad: Tensor): (x_orig,) = ctx.saved_tensors with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -867,7 +869,7 @@ class MaxEig(torch.nn.Module): ): return _no_op(x) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): eps = 1.0e-20 orig_x = x x = x.to(torch.float32) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 436ec53b4..f94da9788 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -86,10 +85,12 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, filter_uneven_sized_batch, setup_logger, str2bool, symlink_or_copy, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -581,7 +582,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save 
model, optimizer, scheduler and training stats to file. @@ -763,7 +764,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -809,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1106,7 +1107,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1207,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index cbde2a2e4..ee05627ae 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -44,7 +44,7 @@ from scaling import ( from torch import Tensor, nn from icefall.dist import get_rank -from icefall.utils import is_jit_tracing, make_pad_mask +from icefall.utils import is_jit_tracing, make_pad_mask, torch_autocast class Zipformer(EncoderInterface): @@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py index 629bec058..3b181bf23 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py @@ -633,7 +633,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -672,7 +674,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py index 7641fa5af..9e16c3fd7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py @@ -786,7 +786,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading 
{lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py index d1b7eec65..f7dd07f8d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py @@ -347,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +360,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py index a6e919e2f..f1ab2a3ec 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -150,7 +150,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py index 323ba2642..a13952dfa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py index 1e638aa7d..32242c94e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -286,7 +286,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -365,7 +365,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -376,7 +378,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index b35e56abc..a26f11c82 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -67,7 +67,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -588,7 +589,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] 
= None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -787,7 +788,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -833,7 +834,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1128,7 +1129,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1228,7 +1229,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py index fa7144f0f..3af3ada2c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py @@ -624,7 +624,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -663,7 +665,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py index e2f08abc6..233f00236 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py @@ -808,7 +808,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py index e497787d3..025b146b9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py @@ -786,7 +786,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + 
torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py index 80604ef4a..70d9841bf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py @@ -347,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +360,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py index 0582b289f..bf0faf9f1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -22,7 +22,7 @@ import torch import torch.nn as nn from encoder_interface import EncoderInterface -from icefall.utils import add_sos, make_pad_mask +from icefall.utils import add_sos, make_pad_mask, torch_autocast class Transducer(nn.Module): @@ -178,7 +178,7 @@ class Transducer(nn.Module): am = self.simple_am_proj(encoder_out_fr) lm = self.simple_lm_proj(decoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -213,7 +213,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py index a82f3562b..9ceec5f5a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py index b98756a54..431760f9a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -286,7 +286,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -362,7 +362,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -373,7 +375,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index c2d877a93..5585d74de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -63,7 +63,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -82,9 +81,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -581,7 +582,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = 
None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -778,7 +779,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -822,7 +823,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1118,7 +1119,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1217,7 +1218,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index 02029c108..aa2fe8e38 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -936,7 +936,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py index 2de56837e..a4fbd93ba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -85,6 +85,20 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--dynamic-batch", + type=int, + default=1, + help="1 to support dynamic batch size. 0 to support only batch size == 1", + ) + + parser.add_argument( + "--enable-int8-quantization", + type=int, + default=1, + help="1 to also export int8 onnx models.", + ) + parser.add_argument( "--epoch", type=int, @@ -257,6 +271,7 @@ def export_encoder_model_onnx( encoder_model: OnnxEncoder, encoder_filename: str, opset_version: int = 11, + dynamic_batch: bool = True, ) -> None: """ Onnx model inputs: @@ -274,6 +289,8 @@ def export_encoder_model_onnx( The filename to save the exported ONNX model. opset_version: The opset version to use. 
+ dynamic_batch: + True to export a model supporting dynamic batch size """ encoder_model.encoder.__class__.forward = ( @@ -379,7 +396,9 @@ def export_encoder_model_onnx( "encoder_out": {0: "N"}, **inputs, **outputs, - }, + } + if dynamic_batch + else {}, ) add_meta_data(filename=encoder_filename, meta_data=meta_data) @@ -389,6 +408,7 @@ def export_decoder_model_onnx( decoder_model: nn.Module, decoder_filename: str, opset_version: int = 11, + dynamic_batch: bool = True, ) -> None: """Export the decoder model to ONNX format. @@ -412,7 +432,7 @@ def export_decoder_model_onnx( """ context_size = decoder_model.decoder.context_size vocab_size = decoder_model.decoder.vocab_size - y = torch.zeros(10, context_size, dtype=torch.int64) + y = torch.zeros(1, context_size, dtype=torch.int64) decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, @@ -425,7 +445,9 @@ def export_decoder_model_onnx( dynamic_axes={ "y": {0: "N"}, "decoder_out": {0: "N"}, - }, + } + if dynamic_batch + else {}, ) meta_data = { "context_size": str(context_size), @@ -438,6 +460,7 @@ def export_joiner_model_onnx( joiner_model: nn.Module, joiner_filename: str, opset_version: int = 11, + dynamic_batch: bool = True, ) -> None: """Export the joiner model to ONNX format. The exported joiner model has two inputs: @@ -452,8 +475,8 @@ def export_joiner_model_onnx( joiner_dim = joiner_model.output_linear.weight.shape[1] logging.info(f"joiner dim: {joiner_dim}") - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) torch.onnx.export( joiner_model, @@ -470,7 +493,9 @@ def export_joiner_model_onnx( "encoder_out": {0: "N"}, "decoder_out": {0: "N"}, "logit": {0: "N"}, - }, + } + if dynamic_batch + else {}, ) meta_data = { "joiner_dim": str(joiner_dim), @@ -629,6 +654,7 @@ def main(): encoder, encoder_filename, opset_version=opset_version, + dynamic_batch=params.dynamic_batch == 1, ) logging.info(f"Exported encoder to {encoder_filename}") @@ -638,6 +664,7 @@ def main(): decoder, decoder_filename, opset_version=opset_version, + dynamic_batch=params.dynamic_batch == 1, ) logging.info(f"Exported decoder to {decoder_filename}") @@ -647,37 +674,39 @@ def main(): joiner, joiner_filename, opset_version=opset_version, + dynamic_batch=params.dynamic_batch == 1, ) logging.info(f"Exported joiner to {joiner_filename}") # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - logging.info("Generate int8 quantization models") + if params.enable_int8_quantization: + logging.info("Generate int8 quantization models") - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul", "Gather"], - weight_type=QuantType.QInt8, - ) + 
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export_rknn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export_rknn.py new file mode 100755 index 000000000..cb872cca0 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export_rknn.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang) + +import argparse +import logging +from pathlib import Path +from typing import List + +from rknn.api import RKNN + +logging.basicConfig(level=logging.WARNING) + +g_platforms = [ + # "rv1103", + # "rv1103b", + # "rv1106", + # "rk2118", + "rk3562", + "rk3566", + "rk3568", + "rk3576", + "rk3588", +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--target-platform", + type=str, + required=True, + help=f"Supported values are: {','.join(g_platforms)}", + ) + + parser.add_argument( + "--in-encoder", + type=str, + required=True, + help="Path to the encoder onnx model", + ) + + parser.add_argument( + "--in-decoder", + type=str, + required=True, + help="Path to the decoder onnx model", + ) + + parser.add_argument( + "--in-joiner", + type=str, + required=True, + help="Path to the joiner onnx model", + ) + + parser.add_argument( + "--out-encoder", + type=str, + required=True, + help="Path to the encoder rknn model", + ) + + parser.add_argument( + "--out-decoder", + type=str, + required=True, + help="Path to the decoder rknn model", + ) + + parser.add_argument( + "--out-joiner", + type=str, + required=True, + help="Path to the joiner rknn model", + ) + + return parser + + +def export_rknn(rknn, filename): + ret = rknn.export_rknn(filename) + if ret != 0: + exit("Export rknn model to {filename} failed!") + + +def init_model(filename: str, target_platform: str, custom_string=None): + rknn = RKNN(verbose=False) + + rknn.config(target_platform=target_platform, custom_string=custom_string) + if not Path(filename).is_file(): + exit(f"{filename} does not exist") + + ret = rknn.load_onnx(model=filename) + if ret != 0: + exit(f"Load model {filename} failed!") + + ret = rknn.build(do_quantization=False) + if ret != 0: + exit("Build model {filename} failed!") + + return rknn + + +class MetaData: + def __init__( + self, + model_type: str, + attention_dims: List[int], + encoder_dims: List[int], + T: int, + left_context_len: List[int], + decode_chunk_len: int, + cnn_module_kernels: List[int], + num_encoder_layers: List[int], + context_size: int, + ): + self.model_type = model_type + self.attention_dims = attention_dims + self.encoder_dims = encoder_dims + self.T = T + self.left_context_len = left_context_len + self.decode_chunk_len = decode_chunk_len + self.cnn_module_kernels = cnn_module_kernels + 
self.num_encoder_layers = num_encoder_layers + self.context_size = context_size + + def __str__(self) -> str: + return self.to_str() + + def to_str(self) -> str: + def to_s(ll): + return ",".join(list(map(str, ll))) + + s = f"model_type={self.model_type}" + s += ";attention_dims=" + to_s(self.attention_dims) + s += ";encoder_dims=" + to_s(self.encoder_dims) + s += ";T=" + str(self.T) + s += ";left_context_len=" + to_s(self.left_context_len) + s += ";decode_chunk_len=" + str(self.decode_chunk_len) + s += ";cnn_module_kernels=" + to_s(self.cnn_module_kernels) + s += ";num_encoder_layers=" + to_s(self.num_encoder_layers) + s += ";context_size=" + str(self.context_size) + + assert len(s) < 1024, (s, len(s)) + + return s + + +def get_meta_data(encoder: str, decoder: str): + import onnxruntime + + session_opts = onnxruntime.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + m_encoder = onnxruntime.InferenceSession( + encoder, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + m_decoder = onnxruntime.InferenceSession( + decoder, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + encoder_meta = m_encoder.get_modelmeta().custom_metadata_map + print(encoder_meta) + + # {'attention_dims': '192,192,192,192,192', 'version': '1', + # 'model_type': 'zipformer', 'encoder_dims': '256,256,256,256,256', + # 'model_author': 'k2-fsa', 'T': '103', + # 'left_context_len': '192,96,48,24,96', + # 'decode_chunk_len': '96', + # 'cnn_module_kernels': '31,31,31,31,31', + # 'num_encoder_layers': '2,2,2,2,2'} + + def to_int_list(s): + return list(map(int, s.split(","))) + + decoder_meta = m_decoder.get_modelmeta().custom_metadata_map + print(decoder_meta) + + model_type = encoder_meta["model_type"] + attention_dims = to_int_list(encoder_meta["attention_dims"]) + encoder_dims = to_int_list(encoder_meta["encoder_dims"]) + T = int(encoder_meta["T"]) + left_context_len = to_int_list(encoder_meta["left_context_len"]) + decode_chunk_len = int(encoder_meta["decode_chunk_len"]) + cnn_module_kernels = to_int_list(encoder_meta["cnn_module_kernels"]) + num_encoder_layers = to_int_list(encoder_meta["num_encoder_layers"]) + context_size = int(decoder_meta["context_size"]) + + return MetaData( + model_type=model_type, + attention_dims=attention_dims, + encoder_dims=encoder_dims, + T=T, + left_context_len=left_context_len, + decode_chunk_len=decode_chunk_len, + cnn_module_kernels=cnn_module_kernels, + num_encoder_layers=num_encoder_layers, + context_size=context_size, + ) + + +class RKNNModel: + def __init__( + self, + encoder: str, + decoder: str, + joiner: str, + target_platform: str, + ): + self.meta = get_meta_data(encoder, decoder) + self.encoder = init_model( + encoder, + custom_string=self.meta.to_str(), + target_platform=target_platform, + ) + self.decoder = init_model(decoder, target_platform=target_platform) + self.joiner = init_model(joiner, target_platform=target_platform) + + def export_rknn(self, encoder, decoder, joiner): + export_rknn(self.encoder, encoder) + export_rknn(self.decoder, decoder) + export_rknn(self.joiner, joiner) + + def release(self): + self.encoder.release() + self.decoder.release() + self.joiner.release() + + +def main(): + args = get_parser().parse_args() + print(vars(args)) + + model = RKNNModel( + encoder=args.in_encoder, + decoder=args.in_decoder, + joiner=args.in_joiner, + target_platform=args.target_platform, + ) + print(model.meta) + + model.export_rknn( + encoder=args.out_encoder, + 
decoder=args.out_decoder, + joiner=args.out_joiner, + ) + + model.release() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py index 298d1889b..e5e513671 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -132,10 +132,18 @@ class OnnxModel: sess_options=self.session_opts, providers=["CPUExecutionProvider"], ) + print("==========Encoder input==========") + for i in self.encoder.get_inputs(): + print(i) + print("==========Encoder output==========") + for i in self.encoder.get_outputs(): + print(i) + self.init_encoder_states() def init_encoder_states(self, batch_size: int = 1): encoder_meta = self.encoder.get_modelmeta().custom_metadata_map + print(encoder_meta) model_type = encoder_meta["model_type"] assert model_type == "zipformer", model_type @@ -232,6 +240,12 @@ class OnnxModel: sess_options=self.session_opts, providers=["CPUExecutionProvider"], ) + print("==========Decoder input==========") + for i in self.decoder.get_inputs(): + print(i) + print("==========Decoder output==========") + for i in self.decoder.get_outputs(): + print(i) decoder_meta = self.decoder.get_modelmeta().custom_metadata_map self.context_size = int(decoder_meta["context_size"]) @@ -247,6 +261,13 @@ class OnnxModel: providers=["CPUExecutionProvider"], ) + print("==========Joiner input==========") + for i in self.joiner.get_inputs(): + print(i) + print("==========Joiner output==========") + for i in self.joiner.get_outputs(): + print(i) + joiner_meta = self.joiner.get_modelmeta().custom_metadata_map self.joiner_dim = int(joiner_meta["joiner_dim"]) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py index aa2dd17fb..f98851f50 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py new file mode 100755 index 000000000..f860aba5d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/test_rknn_on_cpu_simulator.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang) + +import argparse +from pathlib import Path +from typing import List, Tuple + +import kaldi_native_fbank as knf +import numpy as np +import soundfile as sf +from rknn.api import RKNN + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder", + type=str, + required=True, + help="Path to the encoder onnx model", + ) + + parser.add_argument( + "--decoder", + type=str, + required=True, + help="Path to the decoder onnx model", + ) 
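# --------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): the prints added to
# onnx_pretrained.py above enumerate each session's inputs, outputs and
# custom metadata.  The same information can be inspected for any exported
# model as shown below; the model path is hypothetical.
import onnxruntime

sess = onnxruntime.InferenceSession(
    "encoder-epoch-99-avg-1.onnx",  # hypothetical model path
    providers=["CPUExecutionProvider"],
)
for node in sess.get_inputs():
    print(node.name, node.shape, node.type)
for node in sess.get_outputs():
    print(node.name, node.shape, node.type)
print(sess.get_modelmeta().custom_metadata_map)
# --------------------------------------------------------------------------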
+ + parser.add_argument( + "--joiner", + type=str, + required=True, + help="Path to the joiner onnx model", + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--wav", + type=str, + required=True, + help="Path to test wave", + ) + + return parser + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def compute_features(filename: str, dim: int = 80) -> np.ndarray: + """ + Args: + filename: + Path to an audio file. + Returns: + Return a 2-D float32 tensor of shape (T, dim) containing the features. + """ + wave, sample_rate = load_audio(filename) + if sample_rate != 16000: + import librosa + + wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000) + sample_rate = 16000 + + features = [] + opts = knf.FbankOptions() + opts.frame_opts.dither = 0 + opts.mel_opts.num_bins = dim + opts.frame_opts.snip_edges = False + fbank = knf.OnlineFbank(opts) + + fbank.accept_waveform(16000, wave) + tail_paddings = np.zeros(int(0.5 * 16000), dtype=np.float32) + fbank.accept_waveform(16000, tail_paddings) + fbank.input_finished() + for i in range(fbank.num_frames_ready): + f = fbank.get_frame(i) + features.append(f) + + features = np.stack(features, axis=0) + + return features + + +def load_tokens(filename): + tokens = dict() + with open(filename, "r") as f: + for line in f: + t, i = line.split() + tokens[int(i)] = t + return tokens + + +def init_model(filename, target_platform="rk3588", custom_string=None): + rknn = RKNN(verbose=False) + + rknn.config(target_platform=target_platform, custom_string=custom_string) + if not Path(filename).is_file(): + exit(f"{filename} does not exist") + + ret = rknn.load_onnx(model=filename) + if ret != 0: + exit(f"Load model {filename} failed!") + + ret = rknn.build(do_quantization=False) + if ret != 0: + exit("Build model {filename} failed!") + + ret = rknn.init_runtime() + if ret != 0: + exit(f"Failed to init rknn runtime for {filename}") + return rknn + + +class MetaData: + def __init__( + self, + model_type: str, + attention_dims: List[int], + encoder_dims: List[int], + T: int, + left_context_len: List[int], + decode_chunk_len: int, + cnn_module_kernels: List[int], + num_encoder_layers: List[int], + ): + self.model_type = model_type + self.attention_dims = attention_dims + self.encoder_dims = encoder_dims + self.T = T + self.left_context_len = left_context_len + self.decode_chunk_len = decode_chunk_len + self.cnn_module_kernels = cnn_module_kernels + self.num_encoder_layers = num_encoder_layers + + def __str__(self) -> str: + return self.to_str() + + def to_str(self) -> str: + def to_s(ll): + return ",".join(list(map(str, ll))) + + s = f"model_type={self.model_type}" + s += ";attention_dims=" + to_s(self.attention_dims) + s += ";encoder_dims=" + to_s(self.encoder_dims) + s += ";T=" + str(self.T) + s += ";left_context_len=" + to_s(self.left_context_len) + s += ";decode_chunk_len=" + str(self.decode_chunk_len) + s += ";cnn_module_kernels=" + to_s(self.cnn_module_kernels) + s += ";num_encoder_layers=" + to_s(self.num_encoder_layers) + + assert len(s) < 1024, (s, len(s)) + + return s + + +def get_meta_data(encoder: str): + import onnxruntime + + session_opts = onnxruntime.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + 
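# --------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): MetaData.to_str() above
# serializes the model configuration as "key=value" pairs joined by ";",
# with list-valued fields joined by ",".  A consumer of the exported
# custom_string could recover the fields like this (parse_meta is a
# hypothetical helper, not used elsewhere in this file):
def parse_meta(s: str) -> dict:
    ans = {}
    for kv in s.split(";"):
        key, value = kv.split("=", maxsplit=1)
        # keep list-valued fields as lists of strings, scalars as strings
        ans[key] = value.split(",") if "," in value else value
    return ans

# Example:
#   parse_meta("model_type=zipformer;T=103;encoder_dims=256,256,256,256,256")
#   -> {"model_type": "zipformer", "T": "103",
#       "encoder_dims": ["256", "256", "256", "256", "256"]}
# --------------------------------------------------------------------------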
+ m = onnxruntime.InferenceSession( + encoder, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + meta = m.get_modelmeta().custom_metadata_map + print(meta) + # {'attention_dims': '192,192,192,192,192', 'version': '1', + # 'model_type': 'zipformer', 'encoder_dims': '256,256,256,256,256', + # 'model_author': 'k2-fsa', 'T': '103', + # 'left_context_len': '192,96,48,24,96', + # 'decode_chunk_len': '96', + # 'cnn_module_kernels': '31,31,31,31,31', + # 'num_encoder_layers': '2,2,2,2,2'} + + def to_int_list(s): + return list(map(int, s.split(","))) + + model_type = meta["model_type"] + attention_dims = to_int_list(meta["attention_dims"]) + encoder_dims = to_int_list(meta["encoder_dims"]) + T = int(meta["T"]) + left_context_len = to_int_list(meta["left_context_len"]) + decode_chunk_len = int(meta["decode_chunk_len"]) + cnn_module_kernels = to_int_list(meta["cnn_module_kernels"]) + num_encoder_layers = to_int_list(meta["num_encoder_layers"]) + + return MetaData( + model_type=model_type, + attention_dims=attention_dims, + encoder_dims=encoder_dims, + T=T, + left_context_len=left_context_len, + decode_chunk_len=decode_chunk_len, + cnn_module_kernels=cnn_module_kernels, + num_encoder_layers=num_encoder_layers, + ) + + +class RKNNModel: + def __init__( + self, encoder: str, decoder: str, joiner: str, target_platform="rk3588" + ): + self.meta = get_meta_data(encoder) + self.encoder = init_model(encoder, custom_string=self.meta.to_str()) + self.decoder = init_model(decoder) + self.joiner = init_model(joiner) + + def release(self): + self.encoder.release() + self.decoder.release() + self.joiner.release() + + def get_init_states( + self, + ) -> List[np.ndarray]: + + cached_len = [] + cached_avg = [] + cached_key = [] + cached_val = [] + cached_val2 = [] + cached_conv1 = [] + cached_conv2 = [] + + num_encoder_layers = self.meta.num_encoder_layers + encoder_dims = self.meta.encoder_dims + left_context_len = self.meta.left_context_len + attention_dims = self.meta.attention_dims + cnn_module_kernels = self.meta.cnn_module_kernels + + num_encoders = len(num_encoder_layers) + N = 1 + + for i in range(num_encoders): + cached_len.append(np.zeros((num_encoder_layers[i], N), dtype=np.int64)) + cached_avg.append( + np.zeros((num_encoder_layers[i], N, encoder_dims[i]), dtype=np.float32) + ) + cached_key.append( + np.zeros( + (num_encoder_layers[i], left_context_len[i], N, attention_dims[i]), + dtype=np.float32, + ) + ) + + cached_val.append( + np.zeros( + ( + num_encoder_layers[i], + left_context_len[i], + N, + attention_dims[i] // 2, + ), + dtype=np.float32, + ) + ) + cached_val2.append( + np.zeros( + ( + num_encoder_layers[i], + left_context_len[i], + N, + attention_dims[i] // 2, + ), + dtype=np.float32, + ) + ) + cached_conv1.append( + np.zeros( + ( + num_encoder_layers[i], + N, + encoder_dims[i], + cnn_module_kernels[i] - 1, + ), + dtype=np.float32, + ) + ) + cached_conv2.append( + np.zeros( + ( + num_encoder_layers[i], + N, + encoder_dims[i], + cnn_module_kernels[i] - 1, + ), + dtype=np.float32, + ) + ) + + ans = ( + cached_len + + cached_avg + + cached_key + + cached_val + + cached_val2 + + cached_conv1 + + cached_conv2 + ) + # for i, s in enumerate(ans): + # if s.ndim == 4: + # ans[i] = np.transpose(s, (0, 2, 3, 1)) + return ans + + def run_encoder(self, x: np.ndarray, states: List[np.ndarray]): + """ + Args: + x: (T, C), np.float32 + states: A list of states + """ + x = np.expand_dims(x, axis=0) + + out = self.encoder.inference(inputs=[x] + states, data_format="nchw") + # out[0], 
encoder_out, shape (1, 24, 512) + return out[0], out[1:] + + def run_decoder(self, x: np.ndarray): + """ + Args: + x: (1, context_size), np.int64 + Returns: + Return decoder_out, (1, C), np.float32 + """ + return self.decoder.inference(inputs=[x])[0] + + def run_joiner(self, encoder_out: np.ndarray, decoder_out: np.ndarray): + """ + Args: + encoder_out: (1, encoder_out_dim), np.float32 + decoder_out: (1, decoder_out_dim), np.float32 + Returns: + joiner_out: (1, vocab_size), np.float32 + """ + return self.joiner.inference(inputs=[encoder_out, decoder_out])[0] + + +def main(): + args = get_parser().parse_args() + print(vars(args)) + + id2token = load_tokens(args.tokens) + features = compute_features(args.wav) + model = RKNNModel( + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + ) + print(model.meta) + + states = model.get_init_states() + + segment = model.meta.T + offset = model.meta.decode_chunk_len + + context_size = 2 + hyp = [0] * context_size + decoder_input = np.array([hyp], dtype=np.int64) + decoder_out = model.run_decoder(decoder_input) + + i = 0 + while True: + if i + segment > features.shape[0]: + break + x = features[i : i + segment] + i += offset + encoder_out, states = model.run_encoder(x, states) + encoder_out = encoder_out.squeeze(0) # (1, T, C) -> (T, C) + + num_frames = encoder_out.shape[0] + for k in range(num_frames): + joiner_out = model.run_joiner(encoder_out[k : k + 1], decoder_out) + joiner_out = joiner_out.squeeze(0) + max_token_id = joiner_out.argmax() + + # assume 0 is the blank id + if max_token_id != 0: + hyp.append(max_token_id) + decoder_input = np.array([hyp[-context_size:]], dtype=np.int64) + decoder_out = model.run_decoder(decoder_input) + print(hyp) + final_hyp = hyp[context_size:] + print(final_hyp) + text = "".join([id2token[i] for i in final_hyp]) + text = text.replace("▁", " ") + print(text) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 8bd00bbef..4d8a2644d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -66,7 +66,6 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -82,7 +81,14 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -597,7 +603,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -764,7 +770,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -810,7 +816,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1124,7 +1130,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1224,7 +1230,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index c7e45564f..640d72b67 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -44,7 +44,7 @@ from scaling import ( ) from torch import Tensor, nn -from icefall.utils import make_pad_mask, subsequent_chunk_mask +from icefall.utils import make_pad_mask, subsequent_chunk_mask, torch_autocast def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: @@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py index 35158ced4..61c1a9663 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode.py @@ -768,7 +768,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py index a4f52ad7f..e95bb3357 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/decode_gigaspeech.py @@ -788,7 +788,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff 
--git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index da5e144c9..4b97575e6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -70,7 +70,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -86,7 +85,14 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -615,7 +621,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -795,7 +801,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -866,7 +872,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1218,7 +1224,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1320,7 +1326,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index e07777c9f..3cad83a0b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -747,7 +747,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 39a360796..e06594c27 100644 --- 
a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -24,7 +24,7 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import penalize_abs_values_gt -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class Transducer(nn.Module): @@ -172,7 +172,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -207,7 +207,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index c29b8d8c9..693db2beb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 646f30ca1..ad14ec9dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -75,7 +75,6 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -91,7 +90,14 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -608,7 +614,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -790,7 +796,7 @@ def train_one_epoch( giga_train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, rng: random.Random, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -866,7 +872,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1219,7 +1225,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1321,7 +1327,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py index 0b982f4bf..72842cc28 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -69,7 +69,7 @@ class Conformer(Transformer): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers num_decoder_layers (int): number of decoder layers dropout (float): dropout rate diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 1b52aa8b5..283252a46 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -311,8 +311,7 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, + buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 92529e06c..db12ab827 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -398,7 +398,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -428,7 +430,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False + ) G = k2.Fsa.from_dict(d).to(device) if params.method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py 
b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index b3dfab64a..4ad7cb016 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -167,13 +167,15 @@ def main(): subsampling_factor=params.subsampling_factor, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"]) model.to(device) model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -181,7 +183,9 @@ def main(): if params.method == "whole-lattice-rescoring": logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py index cda03b56e..ec700626a 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py @@ -589,7 +589,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -628,7 +630,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method == "whole-lattice-rescoring": diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py index cc4471e2b..1b329e8f3 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py @@ -663,7 +663,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py index 92dea3aa1..4b234a328 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py @@ -347,7 +347,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +360,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, 
map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py index 5c6956324..9714aa537 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py @@ -249,7 +249,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py index 7698ada79..a2ea1dd06 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py @@ -286,7 +286,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -365,7 +365,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -376,7 +378,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 1bfd071de..368bd20fa 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -51,7 +51,6 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import AdamW from torch.optim.lr_scheduler import StepLR @@ -72,9 +71,11 @@ from icefall.lexicon import UniqLexicon from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -550,7 +551,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
@@ -757,7 +758,7 @@ def train_one_epoch( phone_lexicon: UniqLexicon, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -809,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1092,7 +1093,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1198,7 +1199,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 4d9bbf4b1..06b1c05b9 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -222,7 +222,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 90b722bde..9b11df673 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -35,7 +35,7 @@ class Conformer(Transformer): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate cnn_module_kernel (int): Kernel size of convolution module diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 3b86e319e..c5c58f140 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -234,7 +234,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py index 2de4182f1..9f9159cea 100755 --- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -234,7 +234,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, 
map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index 83094ea51..973205078 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -234,7 +234,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index fe9347b95..e17407c5f 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -947,7 +947,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -987,7 +989,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.decoding_method in [ diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index cbfb3728e..6462d22f8 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -1013,7 +1013,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py index 3cda337c0..a4da83949 100755 --- a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py +++ b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py @@ -1049,7 +1049,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py index 99685f2fe..413b5bb1e 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-ctc.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-ctc.py @@ -153,6 +153,13 @@ def get_parser(): help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", ) + parser.add_argument( + "--fp16", + type=str2bool, + default=False, + help="Whether to export models in fp16", + ) + add_model_arguments(parser) return parser @@ -176,6 +183,15 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): onnx.save(model, filename) +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + class OnnxModel(nn.Module): """A wrapper for encoder_embed, Zipformer, and ctc_output layer""" @@ -407,7 +423,7 @@ def main(): opset_version = 13 logging.info("Exporting ctc model") - filename = params.exp_dir / f"model.onnx" + filename = params.exp_dir / "model.onnx" export_ctc_model_onnx( model, filename, @@ -420,7 +436,7 @@ def main(): logging.info("Generate int8 quantization models") - filename_int8 = params.exp_dir / f"model.int8.onnx" + filename_int8 = params.exp_dir / "model.int8.onnx" quantize_dynamic( model_input=filename, model_output=filename_int8, @@ -428,6 +444,10 @@ def main(): weight_type=QuantType.QInt8, ) + if params.fp16: + filename_fp16 = params.exp_dir / "model.fp16.onnx" + export_onnx_fp16(filename, filename_fp16) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py index 1eba6093b..9a715eefd 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py @@ -74,6 +74,20 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--dynamic-batch", + type=int, + default=1, + help="1 to support dynamic batch size. 0 to support only batch size == 1", + ) + + parser.add_argument( + "--enable-int8-quantization", + type=int, + default=1, + help="1 to also export int8 onnx models.", + ) + parser.add_argument( "--epoch", type=int, @@ -136,12 +150,35 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) + parser.add_argument( + "--use-whisper-features", + type=str2bool, + default=False, + help="True to use whisper features. Must match the one used in training", + ) + + parser.add_argument( + "--fp16", + type=str2bool, + default=False, + help="Whether to export models in fp16", + ) + + parser.add_argument( + "--use-external-data", + type=str2bool, + default=False, + help="Set it to true for model file size > 2GB", + ) + add_model_arguments(parser) return parser -def add_meta_data(filename: str, meta_data: Dict[str, str]): +def add_meta_data( + filename: str, meta_data: Dict[str, str], use_external_data: bool = False +): """Add meta data to an ONNX model. It is changed in-place. Args: @@ -150,13 +187,46 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): meta_data: Key-value pairs. 
""" + filename = str(filename) + model = onnx.load(filename) for key, value in meta_data.items(): meta = model.metadata_props.add() meta.key = key meta.value = value - onnx.save(model, filename) + if use_external_data: + # For models file size > 2GB + external_filename = Path(filename).stem + + onnx.save( + model, + filename, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=external_filename + ".weights", + ) + else: + onnx.save(model, filename) + + +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + +def export_onnx_fp16_large_2gb(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path + + onnx_fp16_model = convert_float_to_float16_model_path( + onnx_fp32_path, keep_io_types=True + ) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) class OnnxModel(nn.Module): @@ -270,6 +340,9 @@ def export_streaming_ctc_model_onnx( model: OnnxModel, encoder_filename: str, opset_version: int = 11, + dynamic_batch: bool = True, + use_whisper_features: bool = False, + use_external_data: bool = False, ) -> None: model.encoder.__class__.forward = model.encoder.__class__.streaming_forward @@ -367,6 +440,10 @@ def export_streaming_ctc_model_onnx( "value_head_dims": value_head_dims, "num_heads": num_heads, } + + if use_whisper_features: + meta_data["feature"] = "whisper" + logging.info(f"meta_data: {meta_data}") for i in range(len(init_state[:-2]) // 6): @@ -408,10 +485,16 @@ def export_streaming_ctc_model_onnx( "log_probs": {0: "N"}, **inputs, **outputs, - }, + } + if dynamic_batch + else {}, ) - add_meta_data(filename=encoder_filename, meta_data=meta_data) + add_meta_data( + filename=encoder_filename, + meta_data=meta_data, + use_external_data=use_external_data, + ) @torch.no_grad() @@ -542,20 +625,33 @@ def main(): opset_version = 13 logging.info("Exporting model") - model_filename = params.exp_dir / f"ctc-{suffix}.onnx" + + if params.use_external_data: + model_filename = f"ctc-{suffix}.onnx" + else: + model_filename = params.exp_dir / f"ctc-{suffix}.onnx" + export_streaming_ctc_model_onnx( model, - model_filename, + str(model_filename), opset_version=opset_version, + dynamic_batch=params.dynamic_batch == 1, + use_whisper_features=params.use_whisper_features, + use_external_data=params.use_external_data, ) logging.info(f"Exported model to {model_filename}") - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + if params.enable_int8_quantization: + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - logging.info("Generate int8 quantization models") + logging.info("Generate int8 quantization models") + + if params.use_external_data: + model_filename_int8 = f"ctc-{suffix}.int8.onnx" + else: + model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx" - model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx" quantize_dynamic( model_input=model_filename, model_output=model_filename_int8, @@ -563,6 +659,14 @@ def main(): weight_type=QuantType.QInt8, ) + if params.fp16: + if 
params.use_external_data: + model_filename_fp16 = f"ctc-{suffix}.fp16.onnx" + export_onnx_fp16_large_2gb(model_filename, model_filename_fp16) + else: + model_filename_fp16 = params.exp_dir / f"ctc-{suffix}.fp16.onnx" + export_onnx_fp16(model_filename, model_filename_fp16) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index a35eb5287..daeb86f6a 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -93,6 +93,20 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--dynamic-batch", + type=int, + default=1, + help="1 to support dynamic batch size. 0 to support only batch size == 1", + ) + + parser.add_argument( + "--enable-int8-quantization", + type=int, + default=1, + help="1 to also export int8 onnx models.", + ) + parser.add_argument( "--epoch", type=int, @@ -162,12 +176,47 @@ def get_parser(): help="Whether to export models in fp16", ) + parser.add_argument( + "--use-whisper-features", + type=str2bool, + default=False, + help="True to use whisper features. Must match the one used in training", + ) + + parser.add_argument( + "--use-external-data", + type=str2bool, + default=False, + help="Set it to true for model file size > 2GB", + ) + add_model_arguments(parser) return parser -def add_meta_data(filename: str, meta_data: Dict[str, str]): +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + +def export_onnx_fp16_large_2gb(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16_model_path + + onnx_fp16_model = convert_float_to_float16_model_path( + onnx_fp32_path, keep_io_types=True + ) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + +def add_meta_data( + filename: str, meta_data: Dict[str, str], use_external_data: bool = False +): """Add meta data to an ONNX model. It is changed in-place. 
Args: @@ -182,7 +231,19 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): meta.key = key meta.value = value - onnx.save(model, filename) + if use_external_data: + # For models file size > 2GB + external_filename = Path(filename).stem + + onnx.save( + model, + filename, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=external_filename + ".weights", + ) + else: + onnx.save(model, filename) class OnnxEncoder(nn.Module): @@ -342,6 +403,9 @@ def export_encoder_model_onnx( encoder_filename: str, opset_version: int = 11, feature_dim: int = 80, + dynamic_batch: bool = True, + use_whisper_features: bool = False, + use_external_data: bool = False, ) -> None: encoder_model.encoder.__class__.forward = ( encoder_model.encoder.__class__.streaming_forward @@ -441,6 +505,9 @@ def export_encoder_model_onnx( "value_head_dims": value_head_dims, "num_heads": num_heads, } + if use_whisper_features: + meta_data["feature"] = "whisper" + logging.info(f"meta_data: {meta_data}") for i in range(len(init_state[:-2]) // 6): @@ -482,16 +549,23 @@ def export_encoder_model_onnx( "encoder_out": {0: "N"}, **inputs, **outputs, - }, + } + if dynamic_batch + else {}, ) - add_meta_data(filename=encoder_filename, meta_data=meta_data) + add_meta_data( + filename=encoder_filename, + meta_data=meta_data, + use_external_data=use_external_data, + ) def export_decoder_model_onnx( decoder_model: OnnxDecoder, decoder_filename: str, opset_version: int = 11, + dynamic_batch: bool = True, ) -> None: """Export the decoder model to ONNX format. @@ -514,7 +588,7 @@ def export_decoder_model_onnx( context_size = decoder_model.decoder.context_size vocab_size = decoder_model.decoder.vocab_size - y = torch.zeros(10, context_size, dtype=torch.int64) + y = torch.zeros(1, context_size, dtype=torch.int64) decoder_model = torch.jit.script(decoder_model) torch.onnx.export( decoder_model, @@ -527,7 +601,9 @@ def export_decoder_model_onnx( dynamic_axes={ "y": {0: "N"}, "decoder_out": {0: "N"}, - }, + } + if dynamic_batch + else {}, ) meta_data = { @@ -541,6 +617,7 @@ def export_joiner_model_onnx( joiner_model: nn.Module, joiner_filename: str, opset_version: int = 11, + dynamic_batch: bool = True, ) -> None: """Export the joiner model to ONNX format. 
The exported joiner model has two inputs: @@ -555,8 +632,8 @@ def export_joiner_model_onnx( joiner_dim = joiner_model.output_linear.weight.shape[1] logging.info(f"joiner dim: {joiner_dim}") - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) torch.onnx.export( joiner_model, @@ -573,7 +650,9 @@ def export_joiner_model_onnx( "encoder_out": {0: "N"}, "decoder_out": {0: "N"}, "logit": {0: "N"}, - }, + } + if dynamic_batch + else {}, ) meta_data = { "joiner_dim": str(joiner_dim), @@ -728,12 +807,18 @@ def main(): opset_version = 13 logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + if params.use_external_data: + encoder_filename = f"encoder-{suffix}.onnx" + else: + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" export_encoder_model_onnx( encoder, - encoder_filename, + str(encoder_filename), opset_version=opset_version, feature_dim=params.feature_dim, + dynamic_batch=params.dynamic_batch == 1, + use_whisper_features=params.use_whisper_features, + use_external_data=params.use_external_data, ) logging.info(f"Exported encoder to {encoder_filename}") @@ -743,6 +828,7 @@ def main(): decoder, decoder_filename, opset_version=opset_version, + dynamic_batch=params.dynamic_batch == 1, ) logging.info(f"Exported decoder to {decoder_filename}") @@ -752,57 +838,59 @@ def main(): joiner, joiner_filename, opset_version=opset_version, + dynamic_batch=params.dynamic_batch == 1, ) logging.info(f"Exported joiner to {joiner_filename}") if params.fp16: - from onnxconverter_common import float16 - logging.info("Generate fp16 models") - encoder = onnx.load(encoder_filename) - encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True) - encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" - onnx.save(encoder_fp16, encoder_filename_fp16) + if params.use_external_data: + encoder_filename_fp16 = f"encoder-{suffix}.fp16.onnx" + export_onnx_fp16_large_2gb(encoder_filename, encoder_filename_fp16) + else: + encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" + export_onnx_fp16(encoder_filename, encoder_filename_fp16) - decoder = onnx.load(decoder_filename) - decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True) decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" - onnx.save(decoder_fp16, decoder_filename_fp16) + export_onnx_fp16(decoder_filename, decoder_filename_fp16) - joiner = onnx.load(joiner_filename) - joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True) joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" - onnx.save(joiner_fp16, joiner_filename_fp16) + export_onnx_fp16(joiner_filename, joiner_filename_fp16) # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - logging.info("Generate int8 quantization models") + if params.enable_int8_quantization: + logging.info("Generate int8 quantization models") - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) + if params.use_external_data: + encoder_filename_int8 = 
f"encoder-{suffix}.int8.onnx" + else: + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul", "Gather"], - weight_type=QuantType.QInt8, - ) + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) if __name__ == "__main__": diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index a56a7a3e6..03c7d6f82 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -70,7 +70,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from onnxconverter_common import float16 from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -182,6 +181,15 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): onnx.save(model, filename) +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) + + class OnnxEncoder(nn.Module): """A wrapper for Zipformer and the encoder_proj from the joiner""" @@ -595,20 +603,14 @@ def main(): if params.fp16: logging.info("Generate fp16 models") - encoder = onnx.load(encoder_filename) - encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True) encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" - onnx.save(encoder_fp16, encoder_filename_fp16) + export_onnx_fp16(encoder_filename, encoder_filename_fp16) - decoder = onnx.load(decoder_filename) - decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True) decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" - onnx.save(decoder_fp16, decoder_filename_fp16) + export_onnx_fp16(decoder_filename, decoder_filename_fp16) - joiner = onnx.load(joiner_filename) - joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True) joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" - onnx.save(joiner_fp16, joiner_filename_fp16) + export_onnx_fp16(joiner_filename, joiner_filename_fp16) # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection diff --git 
a/egs/librispeech/ASR/zipformer/export_rknn_ctc_streaming.py b/egs/librispeech/ASR/zipformer/export_rknn_ctc_streaming.py new file mode 100755 index 000000000..4de0a598a --- /dev/null +++ b/egs/librispeech/ASR/zipformer/export_rknn_ctc_streaming.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang) + +import argparse +import logging +from pathlib import Path +from typing import List + +from rknn.api import RKNN +from test_rknn_on_cpu_simulator_ctc_streaming import RKNNModel + +logging.basicConfig(level=logging.WARNING) + +g_platforms = [ + # "rv1103", + # "rv1103b", + # "rv1106", + # "rk2118", + "rk3562", + "rk3566", + "rk3568", + "rk3576", + "rk3588", +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--target-platform", + type=str, + required=True, + help=f"Supported values are: {','.join(g_platforms)}", + ) + + parser.add_argument( + "--in-model", + type=str, + required=True, + help="Path to the onnx model", + ) + + parser.add_argument( + "--out-model", + type=str, + required=True, + help="Path to the rknn model", + ) + + return parser + + +def main(): + args = get_parser().parse_args() + print(vars(args)) + + model = RKNNModel( + model=args.in_model, + target_platform=args.target_platform, + ) + print(model.meta) + + model.export_rknn( + model=args.out_model, + ) + + model.release() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer/export_rknn_transducer_streaming.py b/egs/librispeech/ASR/zipformer/export_rknn_transducer_streaming.py new file mode 100755 index 000000000..27ff81b91 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/export_rknn_transducer_streaming.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang) + +import argparse +import logging +from pathlib import Path +from typing import List + +from rknn.api import RKNN +from test_rknn_on_cpu_simulator_ctc_streaming import ( + MetaData, + get_meta_data, + init_model, + export_rknn, +) + +logging.basicConfig(level=logging.WARNING) + +g_platforms = [ + # "rv1103", + # "rv1103b", + # "rv1106", + # "rk2118", + "rk3562", + "rk3566", + "rk3568", + "rk3576", + "rk3588", +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--target-platform", + type=str, + required=True, + help=f"Supported values are: {','.join(g_platforms)}", + ) + + parser.add_argument( + "--in-encoder", + type=str, + required=True, + help="Path to the encoder onnx model", + ) + + parser.add_argument( + "--in-decoder", + type=str, + required=True, + help="Path to the decoder onnx model", + ) + + parser.add_argument( + "--in-joiner", + type=str, + required=True, + help="Path to the joiner onnx model", + ) + + parser.add_argument( + "--out-encoder", + type=str, + required=True, + help="Path to the encoder rknn model", + ) + + parser.add_argument( + "--out-decoder", + type=str, + required=True, + help="Path to the decoder rknn model", + ) + + parser.add_argument( + "--out-joiner", + type=str, + required=True, + help="Path to the joiner rknn model", + ) + + return parser + + +class RKNNModel: + def __init__( + self, + encoder: str, + decoder: str, + joiner: str, + target_platform: str, + ): + self.meta = get_meta_data(encoder) + self.encoder = init_model( + encoder, + custom_string=self.meta.to_str(), + 
target_platform=target_platform, + ) + self.decoder = init_model(decoder, target_platform=target_platform) + self.joiner = init_model(joiner, target_platform=target_platform) + + def export_rknn(self, encoder, decoder, joiner): + export_rknn(self.encoder, encoder) + export_rknn(self.decoder, decoder) + export_rknn(self.joiner, joiner) + + def release(self): + self.encoder.release() + self.decoder.release() + self.joiner.release() + + +def main(): + args = get_parser().parse_args() + print(vars(args)) + + model = RKNNModel( + encoder=args.in_encoder, + decoder=args.in_decoder, + joiner=args.in_joiner, + target_platform=args.target_platform, + ) + print(model.meta) + + model.export_rknn( + encoder=args.out_encoder, + decoder=args.out_decoder, + joiner=args.out_joiner, + ) + + model.release() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 2ff631914..94e8b273a 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -95,11 +94,13 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + create_grad_scaler, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -140,8 +141,8 @@ def add_finetune_arguments(parser: argparse.ArgumentParser): type=str2bool, default=False, help=""" - Whether to adapt. If true, we will mix 5% of the new data - with 95% of the original data to fine-tune. This is useful + Whether to adapt. If true, we will mix 5%% of the new data + with 95%% of the original data to fine-tune. This is useful if you want to maintain the performance on the original domain """, ) @@ -765,7 +766,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -808,7 +809,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
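The hunks in this file, and in the other training scripts below, replace direct uses of torch.cuda.amp.autocast and GradScaler with torch_autocast and create_grad_scaler imported from icefall.utils. Those helpers are not part of this patch; a minimal sketch of what such compatibility wrappers could look like, assuming their only job is to bridge the torch.cuda.amp deprecation in newer PyTorch, is:

```python
# Hypothetical sketch only; the real helpers live in icefall/utils.py and are not shown here.
import torch


def torch_autocast(enabled: bool = True, dtype: torch.dtype = torch.float16, **kwargs):
    # Prefer the new torch.amp API and fall back to torch.cuda.amp on older PyTorch.
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        return torch.amp.autocast("cuda", enabled=enabled, dtype=dtype, **kwargs)
    return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, **kwargs)


def create_grad_scaler(enabled: bool = True, **kwargs):
    # torch.amp.GradScaler("cuda", ...) only exists in PyTorch >= 2.3.
    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler("cuda", enabled=enabled, **kwargs)
    return torch.cuda.amp.GradScaler(enabled=enabled, **kwargs)
```

Either version accepts the call sites used in these files, e.g. torch_autocast(enabled=params.use_fp16) and create_grad_scaler(enabled=params.use_fp16, init_scale=1.0).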
@@ -985,7 +986,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dls: torch.utils.data.DataLoader, valid_sets: List[str], - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1049,7 +1050,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1134,7 +1135,7 @@ def train_one_epoch( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " + f"lr: {cur_lr: .2e}, " + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) @@ -1373,7 +1374,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1474,7 +1475,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py index fcd07ae34..d1978df52 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py @@ -346,7 +346,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -357,7 +359,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index c7dbe1e0a..6ef250819 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -25,7 +25,7 @@ from encoder_interface import EncoderInterface from lhotse.dataset import SpecAugment from scaling import ScaledLinear -from icefall.utils import add_sos, make_pad_mask, time_warp +from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast class AsrModel(nn.Module): @@ -210,10 +210,10 @@ class AsrModel(nn.Module): ) # Compute consistency regularization loss - exchanged_targets = ctc_output.detach().chunk(2, dim=0) - exchanged_targets = torch.cat( - [exchanged_targets[1], exchanged_targets[0]], dim=0 - ) # exchange: [x1, x2] -> [x2, x1] + batch_size = ctc_output.shape[0] + assert batch_size % 2 == 0, batch_size + # exchange: [x1, x2] -> [x2, x1] + exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0) cr_loss = nn.functional.kl_div( input=ctc_output, 
target=exchanged_targets, @@ -285,7 +285,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -320,7 +320,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 8434fab13..8a1764651 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -121,6 +121,139 @@ class BatchedOptimizer(Optimizer): p.copy_(stacked_params[i]) +def basic_step(group, p, state, grad): + # computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet. + lr = group["lr"] + if p.numel() == p.shape[0]: + lr = lr * group["scalar_lr_scale"] + beta2 = group["betas"][1] + eps = group["eps"] + # p shape: (batch_size,) or (batch_size, 1, [1,..]) + try: + exp_avg_sq = state[ + "exp_avg_sq" + ] # shape: (batch_size,) or (batch_size, 1, [1,..]) + except KeyError: + exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["exp_avg_sq"] = exp_avg_sq + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. + # slower update at the start will help stability anyway. + bias_correction2 = 1 - beta2 ** (state["step"] + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + denom = exp_avg_sq.sqrt().add_(eps) + + return -lr * grad / denom + + +def scaling_step(group, p, state, grad): + delta = basic_step(group, p, state, grad) + if p.numel() == p.shape[0]: + return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.) + + step = state["step"] + size_update_period = group["size_update_period"] + + try: + param_rms = state["param_rms"] + scale_grads = state["scale_grads"] + scale_exp_avg_sq = state["scale_exp_avg_sq"] + except KeyError: + # we know p.ndim > 1 because we'd have returned above if not, so don't worry + # about the speial case of dim=[] that pytorch treats inconsistently. + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + param_rms = param_rms.to(torch.float) + scale_exp_avg_sq = torch.zeros_like(param_rms) + scale_grads = torch.zeros( + size_update_period, *param_rms.shape, dtype=torch.float, device=p.device + ) + state["param_rms"] = param_rms + state["scale_grads"] = scale_grads + state["scale_exp_avg_sq"] = scale_exp_avg_sq + + # on every step, update the gradient w.r.t. the scale of the parameter, we + # store these as a batch and periodically update the size (for speed only, to + # avoid too many operations). + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + + # periodically recompute the value of param_rms. + if step % size_update_period == size_update_period - 1: + param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()) + + param_min_rms = group["param_min_rms"] + + # scale the step size by param_rms. 
This is the most important "scaling" part of + # ScaledAdam + delta *= param_rms.clamp(min=param_min_rms) + + if step % size_update_period == size_update_period - 1 and step > 0: + # This block updates the size of parameter by adding a step ("delta") value in + # the direction of either shrinking or growing it. + beta2 = group["betas"][1] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + batch_size = p.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # The following may help prevent instability: don't allow the scale step to be too large in + # either direction. + scale_step.clamp_(min=-0.1, max=0.1) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) + + delta.add_(p * scale_step) + + return delta + + +def momentum_step(group, p, state, grad): + delta = scaling_step(group, p, state, grad) + beta1 = group["betas"][0] + try: + stored_delta = state["delta"] + except KeyError: + stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["delta"] = stored_delta + stored_delta.mul_(beta1) + stored_delta.add_(delta, alpha=(1 - beta1)) + # we don't bother doing the "bias correction" part of Adam for beta1 because this is just + # an edge effect that affects the first 10 or so batches; and the effect of not doing it + # is just to do a slower update for the first few batches, which will help stability. + return stored_delta + + class ScaledAdam(BatchedOptimizer): """ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update @@ -352,58 +485,26 @@ class ScaledAdam(BatchedOptimizer): raise RuntimeError( "ScaledAdam optimizer does not support sparse gradients" ) - # State initialization - if len(state) == 0: - self._init_state(group, p, state) - self._step_one_batch(group, p, state, clipping_scale) + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 + + grad = ( + p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale) + ) + p += momentum_step(group, p.detach(), state, grad) + + if p.numel() == p.shape[0]: # scalar parameter + scalar_max = group["scalar_max"] + p.clamp_(min=-scalar_max, max=scalar_max) + + state["step"] = cur_step + 1 return loss - def _init_state(self, group: dict, p: Tensor, state: dict): - """ - Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p - is actually the batch dimension, corresponding to batched-together - parameters of a given shape. - - - Args: - group: Dict to look up configuration values. 
- p: The parameter that we are initializing the state for - state: Dict from string to whatever state we are initializing - """ - size_update_period = group["size_update_period"] - - state["step"] = 0 - - kwargs = {"device": p.device, "dtype": p.dtype} - - # 'delta' implements conventional momentum. There are - # several different kinds of update going on, so rather than - # compute "exp_avg" like in Adam, we store and decay a - # parameter-change "delta", which combines all forms of - # update. this is equivalent to how it's done in Adam, - # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - batch_size = p.shape[0] - numel = p.numel() // batch_size - - if numel > 1: - # "param_rms" just periodically records the scalar root-mean-square value of - # the parameter tensor. - # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - state["param_rms"] = param_rms - - state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) - - # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - def _get_clipping_scale( self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] ) -> float: @@ -484,7 +585,7 @@ class ScaledAdam(BatchedOptimizer): ) first_state["num_clipped"] = 0 quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.warn( + logging.warning( f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" ) @@ -499,8 +600,8 @@ class ScaledAdam(BatchedOptimizer): ans = 0.0 if ans < 1.0: first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( + if ans < 0.5: + logging.warning( f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" ) if self.show_dominant_parameters: @@ -508,6 +609,7 @@ class ScaledAdam(BatchedOptimizer): self._show_gradient_dominating_parameter( tuples, tot_sumsq, group["scalar_lr_scale"] ) + self._show_param_with_unusual_grad(tuples) if ans == 0.0: for (p, state, param_names) in tuples: @@ -515,6 +617,55 @@ class ScaledAdam(BatchedOptimizer): return ans + def _show_param_with_unusual_grad( + self, + tuples: List[Tuple[Tensor, dict, List[str]]], + ): + """ + Print information about parameter which has the largest ratio of grad-on-this-batch + divided by normal grad size. + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + largest_ratio = 0.0 + largest_name = "" + # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor) + ratios_names = [] + for (p, state, batch_param_names) in tuples: + dims = list(range(1, p.ndim)) + + def mean(x): + # workaround for bad interface of torch's "mean" for when dims is the empty list. 
+ if len(dims) > 0: + return x.mean(dim=dims) + else: + return x + + grad_ratio = ( + (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims)) + .sqrt() + .to("cpu") + ) + + ratios_names += zip( + grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0) + ) + + ratios_names = sorted(ratios_names, reverse=True) + ratios_names = ratios_names[:10] + ratios_names = [ + (ratio, name, largest_index(tensor)) + for (ratio, name, tensor) in ratios_names + ] + + logging.warning( + f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}" + ) + def _show_gradient_dominating_parameter( self, tuples: List[Tuple[Tensor, dict, List[str]]], @@ -572,7 +723,7 @@ class ScaledAdam(BatchedOptimizer): dominant_rms, dominant_grad, ) = sorted_by_proportion[dominant_param_name] - logging.warn( + logging.warning( f"Parameter dominating tot_sumsq {dominant_param_name}" f" with proportion {dominant_proportion:.2f}," f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" @@ -581,182 +732,11 @@ class ScaledAdam(BatchedOptimizer): f" orig_rms_sq={(dominant_rms**2).item():.3e}" ) - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): - """ - Do the step for one parameter, which is actually going to be a batch of - `real` parameters, with dim 0 as the batch dim. - Args: - group: dict to look up configuration values - p: parameter to update (actually multiple parameters stacked together - as a batch) - state: state-dict for p, to look up the optimizer state - """ - lr = group["lr"] - size_update_period = group["size_update_period"] - beta1 = group["betas"][0] - grad = p.grad - if clipping_scale != 1.0: - grad *= clipping_scale - step = state["step"] - delta = state["delta"] - - delta.mul_(beta1) - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel > 1: - # Update the size/scale of p, and set param_rms - scale_grads = state["scale_grads"] - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - if step % size_update_period == size_update_period - 1: - param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) - if step > 0: - # self._size_update() learns the overall scale on the - # parameter, by shrinking or expanding it. - self._size_update(group, scale_grads, p, state) - - if numel == 1: - # For parameters with 1 element we just use regular Adam. - # Updates delta. - self._step_scalar(group, p, state) - else: - self._step(group, p, state) - - state["step"] = step + 1 - - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. 
- p: The parameter to update - state: The state-dict of p - """ - - param_rms = state["param_rms"] - beta1, beta2 = group["betas"] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - step = state["step"] - batch_size = p.shape[0] - - size_update_period = scale_grads.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - # we don't bother with bias_correction1; this will help prevent divergence - # at the start of training. - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - - # and ensure the parameter rms after update never exceeds param_max_rms. - # We have to look at the trained model for parameters at or around the - # param_max_rms, because sometimes they can indicate a problem with the - # topology or settings. - scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) - - delta = state["delta"] - # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, group: dict, p: Tensor, state: dict): - """ - This function does the core update of self.step(), in the case where the members of - the batch have more than 1 element. - - Args: - group: A dict which will be used to look up configuration values - p: The parameter to be updated - grad: The grad of p - state: The state-dict corresponding to parameter p - - This function modifies p. - """ - grad = p.grad - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - param_min_rms = group["param_min_rms"] - step = state["step"] - - exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) - bias_correction2 = 1 - beta2 ** (this_step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - - denom = exp_avg_sq.sqrt() - denom += eps - grad = grad / denom - - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) - - delta = state["delta"] - delta.add_(grad * alpha) - p.add_(delta) - - def _step_scalar(self, group: dict, p: Tensor, state: dict): - """ - A simplified form of the core update for scalar tensors, where we cannot get a good - estimate of the parameter rms. - """ - beta1, beta2 = group["betas"] - scalar_max = group["scalar_max"] - eps = group["eps"] - lr = group["lr"] * group["scalar_lr_scale"] - grad = p.grad - - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. Don't bother with bias_correction1; - # slower update at the start will help stability anyway. 
- bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps - - delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) - p.clamp_(min=-scalar_max, max=scalar_max) - p.add_(delta) +def largest_index(x: Tensor): + x = x.contiguous() + argmax = x.abs().argmax().item() + return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)] class LRScheduler(object): @@ -787,9 +767,9 @@ class LRScheduler(object): is not the optimizer. """ return { - # the user might try to override the base_lr, so don't include this in the state. - # previously they were included. - # "base_lrs": self.base_lrs, + # the user might try to override the base_lr, so don't include this in the state. + # previously they were included. + # "base_lrs": self.base_lrs, "epoch": self.epoch, "batch": self.batch, } @@ -807,7 +787,6 @@ class LRScheduler(object): self.__dict__.update(state_dict) self.base_lrs = base_lrs - def get_last_lr(self) -> List[float]: """Return last computed learning rate by current scheduler. Will be a list of float.""" return self._last_lr @@ -853,7 +832,7 @@ class LRScheduler(object): def print_lr(self, is_verbose, group, lr): """Display the current learning rate.""" if is_verbose: - logging.warn( + logging.warning( f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" f" of group {group} to {lr:.4e}." ) @@ -1184,7 +1163,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py index 9f3571b08..65ea7c7f2 100755 --- a/egs/librispeech/ASR/zipformer/pretrained.py +++ b/egs/librispeech/ASR/zipformer/pretrained.py @@ -289,7 +289,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index 4341ef61f..90a6ff5b8 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -305,7 +305,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -389,7 +389,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -400,7 +402,9 @@ def main(): "whole-lattice-rescoring", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, 
map_location="cpu")) + G = k2.Fsa.from_dict( + torch.load(params.G, map_location="cpu", weights_only=False) + ) G = G.to(device) if params.method == "whole-lattice-rescoring": # Add epsilon self-loops to G as we will compose diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d345c2931..22aa1b1ca 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -26,6 +26,8 @@ import torch.nn as nn from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from icefall.utils import torch_autocast + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) @@ -160,8 +162,10 @@ class PiecewiseLinear(object): extra_x_vals.append(extra_x_val) if len(extra_x_vals) > 0: x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] + + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( PiecewiseLinear(*zip(x_vals, y_vals1)), PiecewiseLinear(*zip(x_vals, y_vals2)), @@ -306,7 +310,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -759,7 +763,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1014,7 +1018,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1353,7 +1357,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1401,9 +1405,9 @@ class SwooshL(torch.nn.Module): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 if not x.requires_grad: - return k2.swoosh_l_forward(x) + return k2.swoosh_l_forward(x).to(x.dtype) else: - return k2.swoosh_l(x) + return k2.swoosh_l(x).to(x.dtype) # return SwooshLFunction.apply(x) @@ -1430,7 +1434,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1475,9 +1479,9 @@ class SwooshR(torch.nn.Module): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not x.requires_grad: - return k2.swoosh_r_forward(x) + return k2.swoosh_r_forward(x).to(x.dtype) else: - return k2.swoosh_r(x) + return k2.swoosh_r(x).to(x.dtype) # return SwooshRFunction.apply(x) diff --git a/egs/librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py b/egs/librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py new file mode 100755 index 000000000..458508b89 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Xiaomi Corporation 
(authors: Fangjun Kuang) + +import argparse +from pathlib import Path +from typing import List, Tuple + +import kaldi_native_fbank as knf +import numpy as np +import soundfile as sf +from rknn.api import RKNN + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the onnx model", + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--wav", + type=str, + required=True, + help="Path to test wave", + ) + + return parser + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def compute_features(filename: str, dim: int = 80) -> np.ndarray: + """ + Args: + filename: + Path to an audio file. + Returns: + Return a 2-D float32 tensor of shape (T, dim) containing the features. + """ + wave, sample_rate = load_audio(filename) + if sample_rate != 16000: + import librosa + + wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000) + sample_rate = 16000 + + features = [] + opts = knf.FbankOptions() + opts.frame_opts.dither = 0 + opts.mel_opts.num_bins = dim + opts.frame_opts.snip_edges = False + fbank = knf.OnlineFbank(opts) + + fbank.accept_waveform(16000, wave) + tail_paddings = np.zeros(int(0.5 * 16000), dtype=np.float32) + fbank.accept_waveform(16000, tail_paddings) + fbank.input_finished() + for i in range(fbank.num_frames_ready): + f = fbank.get_frame(i) + features.append(f) + + features = np.stack(features, axis=0) + + return features + + +def load_tokens(filename): + tokens = dict() + with open(filename, "r") as f: + for line in f: + t, i = line.split() + tokens[int(i)] = t + return tokens + + +def init_model(filename, target_platform="rk3588", custom_string=None): + rknn = RKNN(verbose=False) + + rknn.config(target_platform=target_platform, custom_string=custom_string) + if not Path(filename).is_file(): + exit(f"{filename} does not exist") + + ret = rknn.load_onnx(model=filename) + if ret != 0: + exit(f"Load model {filename} failed!") + + ret = rknn.build(do_quantization=False) + if ret != 0: + exit("Build model {filename} failed!") + + ret = rknn.init_runtime() + if ret != 0: + exit(f"Failed to init rknn runtime for {filename}") + return rknn + + +class MetaData: + def __init__( + self, + model_type: str, + decode_chunk_len: int, + T: int, + num_encoder_layers: List[int], + encoder_dims: List[int], + cnn_module_kernels: List[int], + left_context_len: List[int], + query_head_dims: List[int], + value_head_dims: List[int], + num_heads: List[int], + ): + self.model_type = model_type + self.decode_chunk_len = decode_chunk_len + self.T = T + self.num_encoder_layers = num_encoder_layers + self.encoder_dims = encoder_dims + self.cnn_module_kernels = cnn_module_kernels + self.left_context_len = left_context_len + self.query_head_dims = query_head_dims + self.value_head_dims = value_head_dims + self.num_heads = num_heads + + def __str__(self) -> str: + return self.to_str() + + def to_str(self) -> str: + def to_s(ll): + return ",".join(list(map(str, ll))) + + s = f"model_type={self.model_type}" + s += ";decode_chunk_len=" + str(self.decode_chunk_len) + s += ";T=" + str(self.T) + s += ";num_encoder_layers=" + to_s(self.num_encoder_layers) + s += 
";encoder_dims=" + to_s(self.encoder_dims) + s += ";cnn_module_kernels=" + to_s(self.cnn_module_kernels) + s += ";left_context_len=" + to_s(self.left_context_len) + s += ";query_head_dims=" + to_s(self.query_head_dims) + s += ";value_head_dims=" + to_s(self.value_head_dims) + s += ";num_heads=" + to_s(self.num_heads) + + assert len(s) < 1024, (s, len(s)) + + return s + + +def get_meta_data(model: str): + import onnxruntime + + session_opts = onnxruntime.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + m = onnxruntime.InferenceSession( + model, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + for i in m.get_inputs(): + print(i) + + print("-----") + + for i in m.get_outputs(): + print(i) + + meta = m.get_modelmeta().custom_metadata_map + print(meta) + """ + {'num_heads': '4,4,4,8,4,4', 'query_head_dims': '32,32,32,32,32,32', + 'cnn_module_kernels': '31,31,15,15,15,31', + 'num_encoder_layers': '2,2,3,4,3,2', ' version': '1', + 'comment': 'streaming ctc zipformer2', + 'model_type': 'zipformer2', + 'encoder_dims': '192,256,384,512,384,256', + 'model_author': 'k2-fsa', 'T': '77', + 'value_head_dims': '12,12,12,12,12,12', + 'left_context_len': '128,64,32,16,32,64', + 'decode_chunk_len': '64'} + """ + + def to_int_list(s): + return list(map(int, s.split(","))) + + model_type = meta["model_type"] + decode_chunk_len = int(meta["decode_chunk_len"]) + T = int(meta["T"]) + num_encoder_layers = to_int_list(meta["num_encoder_layers"]) + encoder_dims = to_int_list(meta["encoder_dims"]) + cnn_module_kernels = to_int_list(meta["cnn_module_kernels"]) + left_context_len = to_int_list(meta["left_context_len"]) + query_head_dims = to_int_list(meta["query_head_dims"]) + value_head_dims = to_int_list(meta["value_head_dims"]) + num_heads = to_int_list(meta["num_heads"]) + + return MetaData( + model_type=model_type, + decode_chunk_len=decode_chunk_len, + T=T, + num_encoder_layers=num_encoder_layers, + encoder_dims=encoder_dims, + cnn_module_kernels=cnn_module_kernels, + left_context_len=left_context_len, + query_head_dims=query_head_dims, + value_head_dims=value_head_dims, + num_heads=num_heads, + ) + + +def export_rknn(rknn, filename): + ret = rknn.export_rknn(filename) + if ret != 0: + exit("Export rknn model to {filename} failed!") + + +class RKNNModel: + def __init__(self, model: str, target_platform="rk3588"): + self.meta = get_meta_data(model) + self.model = init_model(model, custom_string=self.meta.to_str()) + + def export_rknn(self, model: str): + export_rknn(self.model, model) + + def release(self): + self.model.release() + + def get_init_states( + self, + ) -> List[np.ndarray]: + states = [] + + num_encoder_layers = self.meta.num_encoder_layers + encoder_dims = self.meta.encoder_dims + left_context_len = self.meta.left_context_len + cnn_module_kernels = self.meta.cnn_module_kernels + query_head_dims = self.meta.query_head_dims + value_head_dims = self.meta.value_head_dims + num_heads = self.meta.num_heads + + num_encoders = len(num_encoder_layers) + N = 1 + + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + nonlin_attn_head_dim = 3 * embed_dim // 4 + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = cnn_module_kernels[i] // 2 + + for layer in range(num_layers): + cached_key = np.zeros( + (left_context_len[i], N, key_dim), dtype=np.float32 + ) + cached_nonlin_attn = np.zeros( + (1, N, left_context_len[i], 
nonlin_attn_head_dim),
+                    dtype=np.float32,
+                )
+                cached_val1 = np.zeros(
+                    (left_context_len[i], N, value_dim),
+                    dtype=np.float32,
+                )
+                cached_val2 = np.zeros(
+                    (left_context_len[i], N, value_dim),
+                    dtype=np.float32,
+                )
+                cached_conv1 = np.zeros((N, embed_dim, conv_left_pad), dtype=np.float32)
+                cached_conv2 = np.zeros((N, embed_dim, conv_left_pad), dtype=np.float32)
+                states += [
+                    cached_key,
+                    cached_nonlin_attn,
+                    cached_val1,
+                    cached_val2,
+                    cached_conv1,
+                    cached_conv2,
+                ]
+        embed_states = np.zeros((N, 128, 3, 19), dtype=np.float32)
+        states.append(embed_states)
+        processed_lens = np.zeros((N,), dtype=np.int64)
+        states.append(processed_lens)
+
+        return states
+
+    def run_model(self, x: np.ndarray, states: List[np.ndarray]):
+        """
+        Args:
+          x: (T, C), np.float32
+          states: A list of states
+        """
+        x = np.expand_dims(x, axis=0)
+
+        out = self.model.inference(inputs=[x] + states, data_format="nchw")
+        # out[0]: log_probs, (N, T, C)
+        return out[0], out[1:]
+
+
+def main():
+    args = get_parser().parse_args()
+    print(vars(args))
+
+    id2token = load_tokens(args.tokens)
+    features = compute_features(args.wav)
+    model = RKNNModel(
+        model=args.model,
+    )
+    print(model.meta)
+
+    states = model.get_init_states()
+
+    segment = model.meta.T
+    offset = model.meta.decode_chunk_len
+
+    ans = []
+    blank = 0
+    prev = -1
+    i = 0
+    while True:
+        if i + segment > features.shape[0]:
+            break
+        x = features[i : i + segment]
+        i += offset
+        log_probs, states = model.run_model(x, states)
+        log_probs = log_probs[0]  # (N, T, C) -> (T, C)
+        ids = log_probs.argmax(axis=1)
+        for k in ids:
+            if k != blank and k != prev:
+                ans.append(k)
+            prev = k
+    tokens = [id2token[i] for i in ans]
+    underline = "▁"
+    # underline = b"\xe2\x96\x81".decode()
+    text = "".join(tokens).replace(underline, " ").strip()
+
+    print(ans)
+    print(args.wav)
+    print(text)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index c074c32ec..42ae9b9f2 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -79,7 +79,6 @@ from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
 from torch import Tensor
-from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
@@ -98,9 +97,11 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    create_grad_scaler,
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    torch_autocast,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -829,7 +830,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
+    scaler: Optional["GradScaler"] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
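Stepping back to the test_rknn_on_cpu_simulator_ctc_streaming.py script added just above (before the train.py hunks continue): its decoding is a chunked greedy CTC search: feed T feature frames per chunk, hop by decode_chunk_len, carry the encoder states across chunks, and collapse repeated and blank ids. A condensed sketch of that loop, using only functions the script defines, not new functionality:

```python
import numpy as np


def greedy_ctc_stream(model, features: np.ndarray, id2token: dict) -> str:
    # model is the RKNNModel defined in the script above.
    states = model.get_init_states()
    segment, offset = model.meta.T, model.meta.decode_chunk_len
    blank, prev, hyp = 0, -1, []
    for start in range(0, features.shape[0] - segment + 1, offset):
        log_probs, states = model.run_model(features[start : start + segment], states)
        for k in log_probs[0].argmax(axis=1):  # (T, C) -> best token id per frame
            if k != blank and k != prev:
                hyp.append(int(k))
            prev = k
    return "".join(id2token[i] for i in hyp).replace("▁", " ").strip()
```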
@@ -1034,7 +1035,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, @@ -1101,9 +1102,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype - ): + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): loss, loss_info = compute_loss( params=params, model=model, @@ -1165,23 +1164,34 @@ def train_one_epoch( rank=rank, ) - if batch_idx % 100 == 0 and params.use_autocast: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. + if params.use_autocast: cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: if not saved_bad_model: save_bad_model(suffix="-first-warning") saved_bad_model = True + if not params.inf_check: + register_inf_check_hooks(model) logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: save_bad_model() raise_grad_scale_is_too_small_error(cur_grad_scale) + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + if ( + batch_idx % 25 == 0 + and cur_grad_scale < 2.0 + or batch_idx % 100 == 0 + and cur_grad_scale < 8.0 + or batch_idx % 400 == 0 + and cur_grad_scale < 32.0 + ): + scaler.update(cur_grad_scale * 2.0) + if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 @@ -1335,7 +1345,7 @@ def run(rank, world_size, args): clipping_scale=2.0, ) - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=0.1) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") @@ -1438,7 +1448,7 @@ def run(rank, world_size, args): spec_augment=spec_augment, ) - scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1540,9 +1550,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype - ): + with torch_autocast(enabled=params.use_autocast, dtype=params.dtype): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2a0ae0129..e83a89400 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -47,6 +47,8 @@ from scaling import ( ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ @@ -1873,7 +1875,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): 
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_adapter/decode.py b/egs/librispeech/ASR/zipformer_adapter/decode.py index 91533be8d..e8798aed6 100755 --- a/egs/librispeech/ASR/zipformer_adapter/decode.py +++ b/egs/librispeech/ASR/zipformer_adapter/decode.py @@ -1005,7 +1005,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py index bbc582f50..66c401761 100755 --- a/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py +++ b/egs/librispeech/ASR/zipformer_adapter/decode_gigaspeech.py @@ -1050,7 +1050,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 3511590da..fcd7272e9 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -67,7 +67,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -86,9 +85,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -762,7 +763,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -805,7 +806,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. 
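A note on the torch.load changes that recur in these decode and train scripts: PyTorch 2.6 switched the default of torch.load(..., weights_only=...) to True, which refuses to unpickle arbitrary Python objects, so loading icefall checkpoints (which can embed sampler state) or saved k2 graphs such as HLG.pt and LG.pt fails without weights_only=False. The pattern used throughout this patch, shown with a hypothetical helper name purely for illustration:

```python
import torch


def load_trusted(path, device="cpu"):
    # weights_only=False runs the full pickle machinery; only use it for files
    # you produced yourself (checkpoints, HLG/LG graphs), never for untrusted input.
    return torch.load(path, map_location=device, weights_only=False)
```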
@@ -982,7 +983,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dls: torch.utils.data.DataLoader, valid_sets: List[str], - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1052,7 +1053,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1397,7 +1398,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1498,7 +1499,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py index 8e2dfdd72..8bc163db5 100644 --- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py +++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py @@ -50,6 +50,8 @@ from scaling import ( ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ @@ -1916,7 +1918,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_ctc/decode.py b/egs/librispeech/ASR/zipformer_ctc/decode.py index 7f605e2c8..b9eed099c 100755 --- a/egs/librispeech/ASR/zipformer_ctc/decode.py +++ b/egs/librispeech/ASR/zipformer_ctc/decode.py @@ -679,7 +679,9 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load( + f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False + ) ) assert HLG.requires_grad is False @@ -719,7 +721,9 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load( + params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) if params.method in [ diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py index 60112a84e..bd3bfa332 100755 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -46,7 +46,6 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, LRScheduler, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter @@ -65,7 +64,14 @@ from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from 
icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + create_grad_scaler, + setup_logger, + str2bool, + torch_autocast, +) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler] @@ -533,7 +539,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -687,7 +693,7 @@ def train_one_epoch( graph_compiler: BpeCtcTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -726,7 +732,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -987,7 +993,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py index 4d93a905f..acc814a00 100755 --- a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py +++ b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py @@ -1050,7 +1050,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py index 3f36f229f..c26a2f5cc 100755 --- a/egs/librispeech/ASR/zipformer_lora/finetune.py +++ b/egs/librispeech/ASR/zipformer_lora/finetune.py @@ -78,7 +78,6 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -96,9 +95,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -775,7 +776,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -818,7 +819,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: 
Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -995,7 +996,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dls: torch.utils.data.DataLoader, valid_sets: List[str], - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -1065,7 +1066,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1406,7 +1407,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1507,7 +1508,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py index 8d7aa8027..1347570df 100644 --- a/egs/librispeech/ASR/zipformer_lora/scaling.py +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -27,6 +27,8 @@ import torch.nn.functional as F from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from icefall.utils import torch_autocast + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) @@ -307,7 +309,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -863,7 +865,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1118,7 +1120,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1457,7 +1459,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1534,7 +1536,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py index 9ab214e86..2b83d58ef 100755 --- a/egs/librispeech/ASR/zipformer_lora/train.py +++ b/egs/librispeech/ASR/zipformer_lora/train.py @@ -76,7 +76,6 @@ from 
optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -94,9 +93,11 @@ from icefall.hooks import register_inf_check_hooks from icefall.utils import ( AttributeDict, MetricsTracker, + create_grad_scaler, get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -707,7 +708,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -883,7 +884,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -947,7 +948,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1252,7 +1253,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1352,7 +1353,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py index 43865609a..b84b1c32a 100644 --- a/egs/librispeech/ASR/zipformer_lora/zipformer.py +++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py @@ -49,6 +49,8 @@ from scaling import ( ) from torch import Tensor, nn +from icefall.utils import torch_autocast + class Zipformer2(EncoderInterface): """ @@ -1905,7 +1907,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py index 33c0bf199..bd3ce21f5 100755 --- a/egs/librispeech/ASR/zipformer_mmi/decode.py +++ b/egs/librispeech/ASR/zipformer_mmi/decode.py @@ -569,7 +569,9 @@ def main(): if params.decoding_method == "nbest-rescoring-LG": lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") - LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device)) + LG = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device, weights_only=False) + ) LG = k2.Fsa.from_fsas([LG]).to(device) 
LG.lm_scores = LG.scores.clone() @@ -602,7 +604,11 @@ def main(): torch.save(G.as_dict(), params.lang_dir / f"{order}gram.pt") else: logging.info(f"Loading pre-compiled {order}gram.pt") - d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device) + d = torch.load( + params.lang_dir / f"{order}gram.pt", + map_location=device, + weights_only=False, + ) G = k2.Fsa.from_dict(d) G.lm_scores = G.scores.clone() diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py index 6990c90a0..d5667cafa 100755 --- a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py @@ -308,7 +308,9 @@ def main(): if method == "nbest-rescoring-LG": lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") - LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device)) + LG = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device, weights_only=False) + ) LG = k2.Fsa.from_fsas([LG]).to(device) LG.lm_scores = LG.scores.clone() LM = LG @@ -317,7 +319,9 @@ def main(): assert order in ("3", "4") order = int(order) logging.info(f"Loading pre-compiled {order}gram.pt") - d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device) + d = torch.load( + params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) G.lm_scores = G.scores.clone() LM = G diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py index 1e7afc777..ca860b877 100755 --- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py @@ -269,7 +269,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -331,7 +331,9 @@ def main(): if method == "nbest-rescoring-LG": lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") - LG = k2.Fsa.from_dict(torch.load(lg_filename, map_location=device)) + LG = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device, weights_only=False) + ) LG = k2.Fsa.from_fsas([LG]).to(device) LG.lm_scores = LG.scores.clone() LM = LG @@ -340,7 +342,9 @@ def main(): assert order in ("3", "4") order = int(order) logging.info(f"Loading pre-compiled {order}gram.pt") - d = torch.load(params.lang_dir / f"{order}gram.pt", map_location=device) + d = torch.load( + params.lang_dir / f"{order}gram.pt", map_location=device, weights_only=False + ) G = k2.Fsa.from_dict(d) G.lm_scores = G.scores.clone() LM = G diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index c1785a328..e0ca0a6a5 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -64,7 +64,6 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -87,9 +86,11 @@ from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler from icefall.utils import ( AttributeDict, MetricsTracker, + 
create_grad_scaler, encode_supervisions, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -514,7 +515,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -696,7 +697,7 @@ def train_one_epoch( mmi_graph_compiler: MmiTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, + scaler: "GradScaler", model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -744,7 +745,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1037,7 +1038,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1138,7 +1139,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/decode.py b/egs/librispeech/SSL/hubert/decode.py index 837061b8c..f13f8dc9a 100644 --- a/egs/librispeech/SSL/hubert/decode.py +++ b/egs/librispeech/SSL/hubert/decode.py @@ -962,7 +962,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/SSL/hubert/decode_ce.py b/egs/librispeech/SSL/hubert/decode_ce.py index a8d8bc9c2..9529ce627 100644 --- a/egs/librispeech/SSL/hubert/decode_ce.py +++ b/egs/librispeech/SSL/hubert/decode_ce.py @@ -962,7 +962,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py index 17daa3c9d..ea92b3947 100644 --- a/egs/librispeech/SSL/hubert/finetune.py +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -86,6 +86,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -450,7 +451,7 @@ def _to_int_tuple(s: str): def get_encoder_model(params: AttributeDict) -> nn.Module: if hasattr(params, "pretrained_dir"): logging.info(f"Loading {params.pretrained_dir}") - pretrained = torch.load(params.pretrained_dir) + pretrained = torch.load(params.pretrained_dir, weights_only=False) encoder = HubertModel(params) 
encoder.load_state_dict(pretrained["model"]) else: @@ -816,7 +817,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py index 2723cc770..1c1dc25a5 100644 --- a/egs/librispeech/SSL/hubert/finetune_ce.py +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, @@ -450,7 +451,7 @@ def _to_int_tuple(s: str): def get_encoder_model(params: AttributeDict) -> nn.Module: if hasattr(params, "pretrained_dir"): logging.info(f"Loading {params.pretrained_dir}") - pretrained = torch.load(params.pretrained_dir) + pretrained = torch.load(params.pretrained_dir, weights_only=False) encoder = HubertModel(params) encoder.load_state_dict(pretrained["model"]) else: @@ -816,7 +817,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1207,7 +1208,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py index 46a968b69..2c2077376 100644 --- a/egs/librispeech/SSL/hubert/model.py +++ b/egs/librispeech/SSL/hubert/model.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class AsrModel(nn.Module): @@ -221,7 +221,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -256,7 +256,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py index f183d90fd..240cd2c0d 100644 --- a/egs/librispeech/SSL/hubert/pretrain.py +++ b/egs/librispeech/SSL/hubert/pretrain.py @@ -80,6 +80,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -644,7 +645,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py index 94948695d..12f95c16f 100644 --- a/egs/librispeech/SSL/hubert/pretrain_ce.py +++ b/egs/librispeech/SSL/hubert/pretrain_ce.py @@ -80,6 +80,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -644,7 +645,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1036,7 +1037,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py index 4212cd9c6..d048e15e2 100644 --- a/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py +++ b/egs/librispeech/SSL/local/convert_checkpoint_from_fairseq.py @@ -12,7 +12,7 @@ args = parser.parse_args() src = args.src tgt = args.tgt -old_checkpoint = torch.load(src) +old_checkpoint = torch.load(src, weights_only=False) new_checkpoint = OrderedDict() new_checkpoint["model"] = old_checkpoint["model"] torch.save(new_checkpoint, tgt) diff --git a/egs/librispeech/SSL/local/prepare_lang.py b/egs/librispeech/SSL/local/prepare_lang.py index c8cf9b881..aa23c4cb3 100644 --- a/egs/librispeech/SSL/local/prepare_lang.py +++ b/egs/librispeech/SSL/local/prepare_lang.py @@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. 
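A note on the recurring substitution above: the hunks in this patch replace `torch.cuda.amp.autocast(...)` with `torch_autocast(...)` and `GradScaler(...)` with `create_grad_scaler(...)`, both imported from `icefall.utils`. The patch does not include the implementation of those helpers, but the call sites keep their keyword arguments unchanged, so they are presumably thin wrappers that dispatch between the deprecated `torch.cuda.amp` API and the newer device-agnostic `torch.amp` API. A minimal sketch under that assumption (not the actual `icefall.utils` code):

```python
# Hypothetical sketch only; the real icefall.utils implementation is not shown
# in this patch. Assumes the helpers simply pick whichever AMP API the installed
# PyTorch provides, since torch.cuda.amp.autocast/GradScaler are deprecated in
# recent releases in favour of the device-agnostic torch.amp equivalents.
import torch


def torch_autocast(device_type: str = "cuda", **kwargs):
    """Drop-in replacement for torch.cuda.amp.autocast(**kwargs)."""
    if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"):
        return torch.amp.autocast(device_type=device_type, **kwargs)
    return torch.cuda.amp.autocast(**kwargs)


def create_grad_scaler(device: str = "cuda", **kwargs):
    """Drop-in replacement for torch.cuda.amp.GradScaler(**kwargs)."""
    if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler(device, **kwargs)
    return torch.cuda.amp.GradScaler(**kwargs)
```

With wrappers of this shape, existing call sites such as `with torch_autocast(enabled=params.use_fp16):` and `scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)` behave exactly like the code they replace on older PyTorch while avoiding deprecation warnings on newer versions.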
diff --git a/egs/librispeech/SSL/zipformer/decode.py b/egs/librispeech/SSL/zipformer/decode.py index 1562c28b8..9f385ea68 100644 --- a/egs/librispeech/SSL/zipformer/decode.py +++ b/egs/librispeech/SSL/zipformer/decode.py @@ -960,7 +960,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py index c907b41c5..8b044fbb5 100644 --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -81,6 +81,7 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, @@ -749,7 +750,7 @@ def _to_int_tuple(s: str): def get_encoder_model(params: AttributeDict) -> nn.Module: if hasattr(params, "pretrained_dir"): logging.info(f"Loading {params.pretrained_dir}") - pretrained = torch.load(params.pretrained_dir) + pretrained = torch.load(params.pretrained_dir, weights_only=False) encoder = HubertModel(params) encoder.load_state_dict(pretrained["model"]) else: @@ -1115,7 +1116,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1504,7 +1505,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/model.py b/egs/librispeech/SSL/zipformer/model.py index 46a968b69..2c2077376 100644 --- a/egs/librispeech/SSL/zipformer/model.py +++ b/egs/librispeech/SSL/zipformer/model.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn from scaling import ScaledLinear -from icefall.utils import add_sos +from icefall.utils import add_sos, torch_autocast class AsrModel(nn.Module): @@ -221,7 +221,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -256,7 +256,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). 
logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py index 937fb382e..d772f56d0 100644 --- a/egs/librispeech/SSL/zipformer/pretrain.py +++ b/egs/librispeech/SSL/zipformer/pretrain.py @@ -78,6 +78,7 @@ from icefall.utils import ( get_parameter_groups_with_lrs, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -944,7 +945,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1334,7 +1335,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py index e9eff3357..5071a91a8 100644 --- a/egs/librispeech/SSL/zipformer/zipformer.py +++ b/egs/librispeech/SSL/zipformer/zipformer.py @@ -22,6 +22,7 @@ import math import random import warnings from typing import List, Optional, Tuple, Union +from icefall.utils import torch_autocast import torch from encoder_interface import EncoderInterface @@ -1849,7 +1850,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch_autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode.py b/egs/librispeech/WSASR/conformer_ctc2/decode.py index 3fa045533..822df6722 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/decode.py +++ b/egs/librispeech/WSASR/conformer_ctc2/decode.py @@ -578,7 +578,7 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) ) assert HLG.requires_grad is False diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py b/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py index b6b1cb020..95b57b8e8 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py +++ b/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py @@ -457,7 +457,7 @@ def main(): params.num_classes = num_classes - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)) HLG = HLG.to(device) assert HLG.requires_grad is False diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py index 82c68803f..19cce1708 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train.py @@ -84,6 +84,7 @@ from icefall.utils import ( get_texts, setup_logger, str2bool, + torch_autocast, ) LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -757,7 +758,7 @@ 
def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1076,7 +1077,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py index b276d0587..a32183bf7 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py @@ -79,6 +79,7 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler from icefall.utils import ( + torch_autocast, AttributeDict, MetricsTracker, encode_supervisions_otc, @@ -758,7 +759,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1078,7 +1079,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/WSASR/local/compile_hlg.py b/egs/librispeech/WSASR/local/compile_hlg.py index 63791f4cc..645826974 100755 --- a/egs/librispeech/WSASR/local/compile_hlg.py +++ b/egs/librispeech/WSASR/local/compile_hlg.py @@ -78,11 +78,11 @@ def compile_HLG(lm_dir: str, lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) if Path(f"{lm_dir}/{lm}.pt").is_file(): logging.info(f"Loading pre-compiled {lm}") - d = torch.load(f"{lm_dir}/{lm}.pt") + d = torch.load(f"{lm_dir}/{lm}.pt", weights_only=False) G = k2.Fsa.from_dict(d) else: logging.info(f"Loading {lm}.fst.txt") diff --git a/egs/librispeech/WSASR/local/prepare_lang.py b/egs/librispeech/WSASR/local/prepare_lang.py index d913756a1..82785ad6e 100755 --- a/egs/librispeech/WSASR/local/prepare_lang.py +++ b/egs/librispeech/WSASR/local/prepare_lang.py @@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. 
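The other change repeated throughout these hunks is passing `weights_only=False` to `torch.load`. Newer PyTorch releases (2.6 and later) switched the default to `weights_only=True`, whose restricted unpickler rejects files containing Python objects beyond plain tensors and basic containers, which is the case for the checkpoints and k2 graph files these recipes load. A minimal illustration, with a placeholder path that is not taken from this patch:

```python
# Illustrative only; the file path is a placeholder, not a file from this patch.
# With the newer default weights_only=True, torch.load refuses to unpickle
# arbitrary objects, so trusted recipe artifacts (decoding graphs, full training
# checkpoints) are loaded explicitly with weights_only=False.
import k2
import torch

d = torch.load("path/to/HLG.pt", map_location="cpu", weights_only=False)
HLG = k2.Fsa.from_dict(d)
assert HLG.requires_grad is False
```

Since `weights_only=False` allows arbitrary pickled code to run during loading, it should only be used for files from trusted sources, which holds for artifacts produced by the recipes themselves.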
diff --git a/egs/librispeech/WSASR/local/prepare_otc_lang.py b/egs/librispeech/WSASR/local/prepare_otc_lang.py index 01865b865..cfd8a18cd 100755 --- a/egs/librispeech/WSASR/local/prepare_otc_lang.py +++ b/egs/librispeech/WSASR/local/prepare_otc_lang.py @@ -29,7 +29,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. diff --git a/egs/libritts/ASR/zipformer/ctc_decode.py b/egs/libritts/ASR/zipformer/ctc_decode.py index d77aa5962..bd360b74f 100755 --- a/egs/libritts/ASR/zipformer/ctc_decode.py +++ b/egs/libritts/ASR/zipformer/ctc_decode.py @@ -802,7 +802,7 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) ) assert HLG.requires_grad is False @@ -842,7 +842,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) G = k2.Fsa.from_dict(d) if params.decoding_method in [ diff --git a/egs/libritts/ASR/zipformer/decode.py b/egs/libritts/ASR/zipformer/decode.py index 759d9d50a..484a3b0a7 100755 --- a/egs/libritts/ASR/zipformer/decode.py +++ b/egs/libritts/ASR/zipformer/decode.py @@ -1014,7 +1014,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/libritts/TTS/README.md b/egs/libritts/TTS/README.md index 4d4fb8580..67424a1ca 100644 --- a/egs/libritts/TTS/README.md +++ b/egs/libritts/TTS/README.md @@ -1,7 +1,7 @@ # Introduction -LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. -The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. +LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. +The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. The main differences from the LibriSpeech corpus are listed below: 1. The audio files are at 24kHz sampling rate. 2. The speech is split at sentence breaks. @@ -11,16 +11,16 @@ The main differences from the LibriSpeech corpus are listed below: For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced. 
> [!CAUTION] -> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). +> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). > While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. -> +> > By using this framework, you agree to the following: > 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. -> +> > 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. -> +> > 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. -> +> > 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. @@ -49,3 +49,54 @@ To inference, use: --epoch 400 \ --tokens data/tokens.txt ``` + +# [VALL-E](https://arxiv.org/abs/2301.02111) + +./valle contains the code for training VALL-E TTS model. + +Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_libritts). The demo of the model trained with libritts and [libritts-r](https://www.openslr.org/141/) is available [here](https://huggingface.co/spaces/yuekai/valle-libritts-demo). 
+ +Preparation: + +``` +bash prepare.sh --start-stage 4 +``` + +The training command is given below: + +``` +world_size=8 +exp_dir=exp/valle + +## Train AR model +python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +## Train NAR model +# cd ${exp_dir} +# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 +# cd - +python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ + --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ + --exp-dir ${exp_dir} --world-size ${world_size} +``` + +To inference, use: +``` +huggingface-cli login +huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_libritts +top_p=1.0 +python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ + --top-k -1 --temperature 1.0 \ + --text ./libritts.txt \ + --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt --top-p ${top_p} +``` diff --git a/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py new file mode 120000 index 000000000..68579ffd4 --- /dev/null +++ b/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py @@ -0,0 +1 @@ +../../../wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py \ No newline at end of file diff --git a/egs/libritts/TTS/local/prepare_tokens_libritts.py b/egs/libritts/TTS/local/prepare_tokens_libritts.py index faeb611f5..cdc39ea6b 100755 --- a/egs/libritts/TTS/local/prepare_tokens_libritts.py +++ b/egs/libritts/TTS/local/prepare_tokens_libritts.py @@ -31,15 +31,6 @@ from piper_phonemize import phonemize_espeak from tqdm.auto import tqdm -def remove_punc_to_upper(text: str) -> str: - text = text.replace("‘", "'") - text = text.replace("’", "'") - tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") - s_list = [x.upper() if x in tokens else " " for x in text] - s = " ".join("".join(s_list).split()).strip() - return s - - def prepare_tokens_libritts(): output_dir = Path("data/spectrogram") prefix = "libritts" @@ -72,7 +63,7 @@ def prepare_tokens_libritts(): for t in tokens_list: tokens.extend(t) cut.tokens = tokens - cut.supervisions[0].normalized_text = remove_punc_to_upper(text) + cut.supervisions[0].normalized_text = text new_cuts.append(cut) diff --git a/egs/libritts/TTS/prepare.sh b/egs/libritts/TTS/prepare.sh index 44016e6d2..a0a6d2ae1 100755 --- a/egs/libritts/TTS/prepare.sh +++ b/egs/libritts/TTS/prepare.sh @@ -32,7 +32,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then cd vits/monotonic_align python setup.py build_ext --inplace cd ../../ - else + else log "monotonic_align lib already built" fi fi @@ -75,20 +75,20 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 
2: Compute Spectrogram for LibriTTS" mkdir -p data/spectrogram if [ ! -e data/spectrogram/.libritts.done ]; then - ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate + ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate touch data/spectrogram/.libritts.done fi - # Here we shuffle and combine the train-clean-100, train-clean-360 and + # Here we shuffle and combine the train-clean-100, train-clean-360 and # train-other-500 together to form the training set. if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c data/spectrogramlibritts_cuts_train-other-500.jsonl.gz) | \ + <(gunzip -c data/spectrogram/libritts_cuts_train-other-500.jsonl.gz) | \ shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz fi - # Here we shuffle and combine the train-clean-100, train-clean-360 + # Here we shuffle and combine the train-clean-100, train-clean-360 # together to form the training set. if [ ! -f data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz ]; then cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ @@ -108,10 +108,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Prepare phoneme tokens for LibriTTS" # We assume you have installed piper_phonemize and espnet_tts_frontend. # If not, please install them with: - # - piper_phonemize: + # - piper_phonemize: # refer to https://github.com/rhasspy/piper-phonemize, # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend: + # - espnet_tts_frontend: # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/spectrogram/.libritts_with_token.done ]; then ./local/prepare_tokens_libritts.py @@ -123,12 +123,39 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Generate token file" # We assume you have installed piper_phonemize and espnet_tts_frontend. # If not, please install them with: - # - piper_phonemize: + # - piper_phonemize: # refer to https://github.com/rhasspy/piper-phonemize, # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend: + # - espnet_tts_frontend: # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/tokens.txt ]; then ./local/prepare_token_file.py --tokens data/tokens.txt fi fi + +audio_feats_dir=data/tokenized +dataset_parts="--dataset-parts all" # debug "-p dev-clean -p test-clean" +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Tokenize/Fbank LibriTTS for valle" + mkdir -p ${audio_feats_dir} + if [ ! 
-e ${audio_feats_dir}/.libritts.tokenize.done ]; then + python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ + --audio-extractor "Encodec" \ + --batch-duration 400 \ + --src-dir "data/manifests" \ + --output-dir "${audio_feats_dir}" + fi + touch ${audio_feats_dir}/.libritts.tokenize.done + + lhotse combine \ + ${audio_feats_dir}/libritts_cuts_train-clean-100.jsonl.gz \ + ${audio_feats_dir}/libritts_cuts_train-clean-360.jsonl.gz \ + ${audio_feats_dir}/libritts_cuts_train-other-500.jsonl.gz \ + ${audio_feats_dir}/cuts_train.jsonl.gz + lhotse copy \ + ${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \ + ${audio_feats_dir}/cuts_dev.jsonl.gz + lhotse copy \ + ${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \ + ${audio_feats_dir}/cuts_test.jsonl.gz +fi diff --git a/egs/libritts/TTS/valle b/egs/libritts/TTS/valle new file mode 120000 index 000000000..c8fe8fdb0 --- /dev/null +++ b/egs/libritts/TTS/valle @@ -0,0 +1 @@ +../../wenetspeech4tts/TTS/valle/ \ No newline at end of file diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 1cd6e8fd7..f5495eeaf 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -131,12 +131,12 @@ To inference, use: wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 -./matcha/inference \ +./matcha/infer.py \ --exp-dir ./matcha/exp-new-3 \ --epoch 4000 \ --tokens ./data/tokens.txt \ --vocoder ./generator_v1 \ - --input-text "how are you doing?" + --input-text "how are you doing?" \ --output-wav ./generated.wav ``` @@ -166,7 +166,7 @@ To export the checkpoint to onnx: --tokens ./data/tokens.txt ``` -The above command generate the following files: +The above command generates the following files: - model-steps-2.onnx - model-steps-3.onnx @@ -176,6 +176,15 @@ The above command generate the following files: where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver. +**HINT**: If you get the following error while running `export_onnx.py`: + +``` +torch.onnx.errors.UnsupportedOperatorError: Exporting the operator +'aten::scaled_dot_product_attention' to ONNX opset version 14 is not supported. +``` + +please use `torch>=2.2.0`. + To export the Hifigan vocoder to onnx, please use: diff --git a/egs/ljspeech/TTS/local/audio.py b/egs/ljspeech/TTS/local/audio.py new file mode 120000 index 000000000..b70d91c92 --- /dev/null +++ b/egs/ljspeech/TTS/local/audio.py @@ -0,0 +1 @@ +../matcha/audio.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py index 5152ae675..906025b7f 100755 --- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -27,102 +27,17 @@ The generated fbank features are saved in data/fbank. 
import argparse import logging import os -from dataclasses import dataclass from pathlib import Path -from typing import Union -import numpy as np import torch +from fbank import MatchaFbank, MatchaFbankConfig from lhotse import CutSet, LilcomChunkyWriter, load_manifest from lhotse.audio import RecordingSet -from lhotse.features.base import FeatureExtractor, register_extractor from lhotse.supervision import SupervisionSet -from lhotse.utils import Seconds, compute_num_frames -from matcha.audio import mel_spectrogram from icefall.utils import get_executor -@dataclass -class MyFbankConfig: - n_fft: int - n_mels: int - sampling_rate: int - hop_length: int - win_length: int - f_min: float - f_max: float - - -@register_extractor -class MyFbank(FeatureExtractor): - - name = "MyFbank" - config_type = MyFbankConfig - - def __init__(self, config): - super().__init__(config=config) - - @property - def device(self) -> Union[str, torch.device]: - return self.config.device - - def feature_dim(self, sampling_rate: int) -> int: - return self.config.n_mels - - def extract( - self, - samples: np.ndarray, - sampling_rate: int, - ) -> torch.Tensor: - # Check for sampling rate compatibility. - expected_sr = self.config.sampling_rate - assert sampling_rate == expected_sr, ( - f"Mismatched sampling rate: extractor expects {expected_sr}, " - f"got {sampling_rate}" - ) - samples = torch.from_numpy(samples) - assert samples.ndim == 2, samples.shape - assert samples.shape[0] == 1, samples.shape - - mel = ( - mel_spectrogram( - samples, - self.config.n_fft, - self.config.n_mels, - self.config.sampling_rate, - self.config.hop_length, - self.config.win_length, - self.config.f_min, - self.config.f_max, - center=False, - ) - .squeeze() - .t() - ) - - assert mel.ndim == 2, mel.shape - assert mel.shape[1] == self.config.n_mels, mel.shape - - num_frames = compute_num_frames( - samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate - ) - - if mel.shape[0] > num_frames: - mel = mel[:num_frames] - elif mel.shape[0] < num_frames: - mel = mel.unsqueeze(0) - mel = torch.nn.functional.pad( - mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" - ).squeeze(0) - - return mel.numpy() - - @property - def frame_shift(self) -> Seconds: - return self.config.hop_length / self.config.sampling_rate - - def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int): logging.info(f"num_jobs: {num_jobs}") logging.info(f"src_dir: {src_dir}") logging.info(f"output_dir: {output_dir}") - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=22050, @@ -158,6 +73,8 @@ def compute_fbank_ljspeech(num_jobs: int): f_min=0, f_max=8000, ) + if not torch.cuda.is_available(): + config.device = "cpu" prefix = "ljspeech" suffix = "jsonl.gz" @@ -170,7 +87,7 @@ def compute_fbank_ljspeech(num_jobs: int): src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet ) - extractor = MyFbank(config) + extractor = MatchaFbank(config) with get_executor() as ex: # Initialize the executor only once. 
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" diff --git a/egs/ljspeech/TTS/local/fbank.py b/egs/ljspeech/TTS/local/fbank.py new file mode 120000 index 000000000..5bcf1fde5 --- /dev/null +++ b/egs/ljspeech/TTS/local/fbank.py @@ -0,0 +1 @@ +../matcha/fbank.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py index 9535ba9f4..68159ae03 100755 --- a/egs/ljspeech/TTS/local/validate_manifest.py +++ b/egs/ljspeech/TTS/local/validate_manifest.py @@ -33,7 +33,6 @@ import argparse import logging from pathlib import Path -from compute_fbank_ljspeech import MyFbank from lhotse import CutSet, load_manifest_lazy from lhotse.dataset.speech_synthesis import validate_for_tts diff --git a/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py deleted file mode 120000 index 85255ba0c..000000000 --- a/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py +++ /dev/null @@ -1 +0,0 @@ -../local/compute_fbank_ljspeech.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py index 487ea2995..3c653fbf1 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx.py +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -93,14 +93,14 @@ class ModelWrapper(torch.nn.Module): self, x: torch.Tensor, x_lengths: torch.Tensor, - temperature: torch.Tensor, + noise_scale: torch.Tensor, length_scale: torch.Tensor, ) -> torch.Tensor: """ Args: : x: (batch_size, num_tokens), torch.int64 x_lengths: (batch_size,), torch.int64 - temperature: (1,), torch.float32 + noise_scale: (1,), torch.float32 length_scale (1,), torch.float32 Returns: audio: (batch_size, num_samples) @@ -110,7 +110,7 @@ class ModelWrapper(torch.nn.Module): x=x, x_lengths=x_lengths, n_timesteps=self.num_steps, - temperature=temperature, + temperature=noise_scale, length_scale=length_scale, )["mel"] # mel: (batch_size, feat_dim, num_frames) @@ -127,7 +127,6 @@ def main(): params.update(vars(args)) tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size params.model_args.n_vocab = params.vocab_size @@ -153,17 +152,17 @@ def main(): # encoder has a large initial length x = torch.ones(1, 1000, dtype=torch.int64) x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) - temperature = torch.tensor([1.0]) + noise_scale = torch.tensor([1.0]) length_scale = torch.tensor([1.0]) opset_version = 14 filename = f"model-steps-{num_steps}.onnx" torch.onnx.export( wrapper, - (x, x_lengths, temperature, length_scale), + (x, x_lengths, noise_scale, length_scale), filename, opset_version=opset_version, - input_names=["x", "x_length", "temperature", "length_scale"], + input_names=["x", "x_length", "noise_scale", "length_scale"], output_names=["mel"], dynamic_axes={ "x": {0: "N", 1: "L"}, @@ -177,12 +176,16 @@ def main(): "language": "English", "voice": "en-us", "has_espeak": 1, + "jieba": 0, "n_speakers": 1, "sample_rate": 22050, "version": 1, + "pad_id": tokenizer.pad_id, "model_author": "icefall", "maintainer": "k2-fsa", + "use_eos_bos": 1, "dataset": "LJ Speech", + "dataset_url": "https://keithito.com/LJ-Speech-Dataset/", "num_ode_steps": num_steps, } add_meta_data(filename=filename, meta_data=meta_data) diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py index 63d1fac20..5c96b3bc7 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py +++ 
b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py @@ -7,7 +7,7 @@ from typing import Any, Dict import onnx import torch -from inference import load_vocoder +from infer import load_vocoder def add_meta_data(filename: str, meta_data: Dict[str, Any]): diff --git a/egs/ljspeech/TTS/matcha/fbank.py b/egs/ljspeech/TTS/matcha/fbank.py new file mode 100644 index 000000000..cc94a301f --- /dev/null +++ b/egs/ljspeech/TTS/matcha/fbank.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass +from typing import Union + +import numpy as np +import torch +from audio import mel_spectrogram +from lhotse.features.base import FeatureExtractor, register_extractor +from lhotse.utils import Seconds, compute_num_frames + + +@dataclass +class MatchaFbankConfig: + n_fft: int + n_mels: int + sampling_rate: int + hop_length: int + win_length: int + f_min: float + f_max: float + device: str = "cuda" + + +@register_extractor +class MatchaFbank(FeatureExtractor): + + name = "MatchaFbank" + config_type = MatchaFbankConfig + + def __init__(self, config): + super().__init__(config=config) + + @property + def device(self) -> Union[str, torch.device]: + return self.config.device + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.n_mels + + def extract( + self, + samples: np.ndarray, + sampling_rate: int, + ) -> torch.Tensor: + # Check for sampling rate compatibility. + expected_sr = self.config.sampling_rate + assert sampling_rate == expected_sr, ( + f"Mismatched sampling rate: extractor expects {expected_sr}, " + f"got {sampling_rate}" + ) + samples = torch.from_numpy(samples).to(self.device) + assert samples.ndim == 2, samples.shape + assert samples.shape[0] == 1, samples.shape + + mel = ( + mel_spectrogram( + samples, + self.config.n_fft, + self.config.n_mels, + self.config.sampling_rate, + self.config.hop_length, + self.config.win_length, + self.config.f_min, + self.config.f_max, + center=False, + ) + .squeeze() + .t() + ) + + assert mel.ndim == 2, mel.shape + assert mel.shape[1] == self.config.n_mels, mel.shape + + num_frames = compute_num_frames( + samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate + ) + + if mel.shape[0] > num_frames: + mel = mel[:num_frames] + elif mel.shape[0] < num_frames: + mel = mel.unsqueeze(0) + mel = torch.nn.functional.pad( + mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" + ).squeeze(0) + + return mel.cpu().numpy() + + @property + def frame_shift(self) -> Seconds: + return self.config.hop_length / self.config.sampling_rate diff --git a/egs/ljspeech/TTS/matcha/hifigan/xutils.py b/egs/ljspeech/TTS/matcha/hifigan/xutils.py index eefadcb7a..2c0d00823 100644 --- a/egs/ljspeech/TTS/matcha/hifigan/xutils.py +++ b/egs/ljspeech/TTS/matcha/hifigan/xutils.py @@ -41,7 +41,7 @@ def get_padding(kernel_size, dilation=1): def load_checkpoint(filepath, device): assert os.path.isfile(filepath) print(f"Loading '{filepath}'") - checkpoint_dict = torch.load(filepath, map_location=device) + checkpoint_dict = torch.load(filepath, map_location=device, weights_only=False) print("Complete.") return checkpoint_dict diff --git a/egs/ljspeech/TTS/matcha/infer.py b/egs/ljspeech/TTS/matcha/infer.py new file mode 100755 index 000000000..8ccd35264 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/infer.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) + +import argparse +import datetime as dt +import json +import logging +from pathlib import Path + +import soundfile as sf +import torch +import torch.nn as nn +from hifigan.config import v1, v2, v3 +from hifigan.denoiser import Denoiser +from hifigan.models import Generator as HiFiGAN +from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import LJSpeechTtsDataModule + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=4000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--vocoder", + type=Path, + default="./generator_v1", + help="Path to the vocoder", + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + # The following arguments are used for inference on single text + parser.add_argument( + "--input-text", + type=str, + required=False, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=False, + help="The filename of the wave to save the generated speech", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=22050, + help="The sampling rate of the generated speech (default: 22050 for LJSpeech)", + ) + + return parser + + +def load_vocoder(checkpoint_path: Path) -> nn.Module: + checkpoint_path = str(checkpoint_path) + if checkpoint_path.endswith("v1"): + h = AttributeDict(v1) + elif checkpoint_path.endswith("v2"): + h = AttributeDict(v2) + elif checkpoint_path.endswith("v3"): + h = AttributeDict(v3) + else: + raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") + + hifigan = HiFiGAN(h).to("cpu") + hifigan.load_state_dict( + torch.load(checkpoint_path, map_location="cpu", weights_only=False)["generator"] + ) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + + +def to_waveform( + mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module +) -> torch.Tensor: + audio = vocoder(mel).clamp(-1, 1) + audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() + return audio.squeeze() + + +def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict: + x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) + x = torch.tensor(x, dtype=torch.long, device=device) + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) + return {"x_orig": text, "x": x, "x_lengths": x_lengths} + + +def synthesize( + model: nn.Module, + tokenizer: Tokenizer, + n_timesteps: int, + text: str, + length_scale: float, + temperature: float, + device: str = "cpu", + spks=None, +) -> dict: + text_processed = process_text(text=text, tokenizer=tokenizer, device=device) + start_t = dt.datetime.now() + output = model.synthesise( + text_processed["x"], + text_processed["x_lengths"], + n_timesteps=n_timesteps, + temperature=temperature, + spks=spks, + length_scale=length_scale, + ) + # merge everything to one 
dict + output.update({"start_t": start_t, **text_processed}) + return output + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + vocoder: nn.Module, + denoiser: nn.Module, + tokenizer: Tokenizer, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + texts = [c.supervisions[0].normalized_text for c in batch["cut"]] + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + for i in range(batch_size): + output = synthesize( + model=model, + tokenizer=tokenizer, + n_timesteps=params.n_timesteps, + text=texts[i], + length_scale=params.length_scale, + temperature=params.temperature, + device=device, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write( + file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav", + data=output["waveform"], + samplerate=params.data_args.sampling_rate, + subtype="PCM_16", + ) + sf.write( + file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav", + data=audio[i].numpy(), + samplerate=params.data_args.sampling_rate, + subtype="PCM_16", + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + +@torch.inference_mode() +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + + # Number of ODE Solver steps + params.n_timesteps = 2 + + # Changes to the speaking rate + params.length_scale = 1.0 + + # Sampling temperature + params.temperature = 0.667 + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + model.eval() + + # we need cut ids to organize tts results. 
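The mel-to-waveform step in infer.py relies on a fixed shape contract between the acoustic model and the vocoder. Below is a minimal, hedged sketch of that contract using dummy stand-ins (a toy upsampling module and a no-op denoiser, not the real HiFi-GAN or the hifigan Denoiser): the vocoder maps (batch, n_mels, num_frames) mel features to (batch, 1, num_samples) audio, which is then clamped, lightly denoised, and squeezed to a 1-D waveform, mirroring to_waveform() above.

```python
# Dummy stand-ins only; the real pipeline uses HiFi-GAN and hifigan.denoiser.Denoiser.
import torch
import torch.nn as nn


class DummyVocoder(nn.Module):
    """Toy vocoder: upsamples 80-dim mel frames by a hop size of 256."""

    def __init__(self, n_mels: int = 80, hop: int = 256):
        super().__init__()
        self.up = nn.ConvTranspose1d(n_mels, 1, kernel_size=hop, stride=hop)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        # (batch, n_mels, num_frames) -> (batch, 1, num_frames * hop)
        return torch.tanh(self.up(mel))


def dummy_denoise(audio: torch.Tensor, strength: float = 0.00025) -> torch.Tensor:
    # The real Denoiser subtracts a scaled bias spectrum; this stand-in only
    # attenuates the signal so the call site looks the same.
    return audio * (1.0 - strength)


if __name__ == "__main__":
    mel = torch.randn(1, 80, 120)               # (batch, n_mels, num_frames)
    vocoder = DummyVocoder()
    audio = vocoder(mel).clamp(-1, 1)           # (1, 1, 30720)
    audio = dummy_denoise(audio.squeeze(0)).squeeze()
    print(audio.shape)                          # torch.Size([30720])
```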
+ args.return_cuts = True + ljspeech = LJSpeechTtsDataModule(args) + + test_cuts = ljspeech.test_cuts() + test_dl = ljspeech.test_dataloaders(test_cuts) + + if not Path(params.vocoder).is_file(): + raise ValueError(f"{params.vocoder} does not exist") + + vocoder = load_vocoder(params.vocoder) + vocoder.to(device) + + denoiser = Denoiser(vocoder, mode="zeros") + denoiser.to(device) + + if params.input_text is not None and params.output_wav is not None: + logging.info("Synthesizing a single text") + output = synthesize( + model=model, + tokenizer=tokenizer, + n_timesteps=params.n_timesteps, + text=params.input_text, + length_scale=params.length_scale, + temperature=params.temperature, + device=device, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write( + file=params.output_wav, + data=output["waveform"], + samplerate=params.sampling_rate, + subtype="PCM_16", + ) + else: + logging.info("Decoding the test set") + infer_dataset( + dl=test_dl, + params=params, + model=model, + vocoder=vocoder, + denoiser=denoiser, + tokenizer=tokenizer, + ) + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py deleted file mode 100755 index 64abd8e50..000000000 --- a/egs/ljspeech/TTS/matcha/inference.py +++ /dev/null @@ -1,199 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) - -import argparse -import datetime as dt -import json -import logging -from pathlib import Path - -import soundfile as sf -import torch -from matcha.hifigan.config import v1, v2, v3 -from matcha.hifigan.denoiser import Denoiser -from matcha.hifigan.models import Generator as HiFiGAN -from tokenizer import Tokenizer -from train import get_model, get_params - -from icefall.checkpoint import load_checkpoint -from icefall.utils import AttributeDict - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=4000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="matcha/exp-new-3", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--vocoder", - type=Path, - default="./generator_v1", - help="Path to the vocoder", - ) - - parser.add_argument( - "--tokens", - type=Path, - default="data/tokens.txt", - ) - - parser.add_argument( - "--cmvn", - type=str, - default="data/fbank/cmvn.json", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--input-text", - type=str, - required=True, - help="The text to generate speech for", - ) - - parser.add_argument( - "--output-wav", - type=str, - required=True, - help="The filename of the wave to save the generated speech", - ) - - return parser - - -def load_vocoder(checkpoint_path): - checkpoint_path = str(checkpoint_path) - if checkpoint_path.endswith("v1"): - h = AttributeDict(v1) - elif checkpoint_path.endswith("v2"): - h = AttributeDict(v2) - elif checkpoint_path.endswith("v3"): - h = AttributeDict(v3) - else: - raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") - - hifigan = HiFiGAN(h).to("cpu") - hifigan.load_state_dict( - torch.load(checkpoint_path, map_location="cpu")["generator"] - ) - _ = hifigan.eval() - hifigan.remove_weight_norm() - return hifigan - - -def to_waveform(mel, vocoder, denoiser): - audio = vocoder(mel).clamp(-1, 1) - audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() - return audio.cpu().squeeze() - - -def process_text(text: str, tokenizer): - x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) - x = torch.tensor(x, dtype=torch.long) - x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu") - return {"x_orig": text, "x": x, "x_lengths": x_lengths} - - -def synthesise( - model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None -): - text_processed = process_text(text, tokenizer) - start_t = dt.datetime.now() - output = model.synthesise( - text_processed["x"], - text_processed["x_lengths"], - n_timesteps=n_timesteps, - temperature=temperature, - spks=spks, - length_scale=length_scale, - ) - # merge everything to one dict - output.update({"start_t": start_t, **text_processed}) - return output - - -@torch.inference_mode() -def main(): - parser = get_parser() - args = parser.parse_args() - params = get_params() - - params.update(vars(args)) - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - params.model_args.n_vocab = params.vocab_size - - with open(params.cmvn) as f: - stats = json.load(f) - params.data_args.data_statistics.mel_mean = stats["fbank_mean"] - params.data_args.data_statistics.mel_std = stats["fbank_std"] - - params.model_args.data_statistics.mel_mean = stats["fbank_mean"] - params.model_args.data_statistics.mel_std = stats["fbank_std"] - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file(): - raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist") - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.eval() - - if not Path(params.vocoder).is_file(): - raise ValueError(f"{params.vocoder} does not exist") - - vocoder = load_vocoder(params.vocoder) - denoiser = Denoiser(vocoder, mode="zeros") - - # Number of ODE Solver steps - n_timesteps = 2 - - # Changes to the speaking rate - length_scale = 1.0 - - # Sampling temperature - temperature = 0.667 - - output = synthesise( - model=model, - 
tokenizer=tokenizer, - n_timesteps=n_timesteps, - text=params.input_text, - length_scale=length_scale, - temperature=temperature, - ) - output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) - - sf.write(params.output_wav, output["waveform"], 22050, "PCM_16") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/ljspeech/TTS/matcha/models/components/decoder.py b/egs/ljspeech/TTS/matcha/models/components/decoder.py index 14d19f5d4..102d87713 100644 --- a/egs/ljspeech/TTS/matcha/models/components/decoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/decoder.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from conformer import ConformerBlock from diffusers.models.activations import get_activation from einops import pack, rearrange, repeat -from matcha.models.components.transformer import BasicTransformerBlock +from models.components.transformer import BasicTransformerBlock class SinusoidalPosEmb(torch.nn.Module): diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py index 997689b1c..eb795ef32 100644 --- a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py +++ b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py @@ -2,7 +2,7 @@ from abc import ABC import torch import torch.nn.functional as F -from matcha.models.components.decoder import Decoder +from models.components.decoder import Decoder class BASECFM(torch.nn.Module, ABC): diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py index ca77cba51..364ff1938 100644 --- a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py @@ -5,7 +5,7 @@ import math import torch import torch.nn as nn from einops import rearrange -from matcha.model import sequence_mask +from model import sequence_mask class LayerNorm(nn.Module): diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py index 330d1dc47..fe0a72402 100644 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -2,17 +2,17 @@ import datetime as dt import math import random -import matcha.monotonic_align as monotonic_align +import monotonic_align as monotonic_align import torch -from matcha.model import ( +from model import ( denormalize, duration_loss, fix_len_compatibility, generate_path, sequence_mask, ) -from matcha.models.components.flow_matching import CFM -from matcha.models.components.text_encoder import TextEncoder +from models.components.flow_matching import CFM +from models.components.text_encoder import TextEncoder class MatchaTTS(torch.nn.Module): # 🍵 diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore index 28bdad6b8..3def4ae26 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore +++ b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore @@ -1,3 +1,3 @@ build core.c -*.so +*.so \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py index 5b26fe474..f87ae1bd3 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py +++ 
b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py @@ -1,8 +1,7 @@ -# Copied from -# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/__init__.py import numpy as np import torch -from matcha.monotonic_align.core import maximum_path_c + +from .core import maximum_path_c def maximum_path(value, mask): diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx index eabc7f273..091fcc3a5 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx +++ b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx @@ -1,5 +1,3 @@ -# Copied from -# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/core.pyx import numpy as np cimport cython diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py index df26c633e..beacf2e36 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py @@ -1,12 +1,30 @@ -# Copied from +# Modified from # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/setup.py -from distutils.core import setup - -import numpy from Cython.Build import cythonize +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext as _build_ext + +class build_ext(_build_ext): + """Overwrite build_ext.""" + + def finalize_options(self): + """Prevent numpy from thinking it is still in its setup process.""" + _build_ext.finalize_options(self) + __builtins__.__NUMPY_SETUP__ = False + import numpy + + self.include_dirs.append(numpy.get_include()) + + +exts = [ + Extension( + name="core", + sources=["core.pyx"], + ) +] setup( name="monotonic_align", - ext_modules=cythonize("core.pyx"), - include_dirs=[numpy.get_include()], + ext_modules=cythonize(exts, language_level=3), + cmdclass={"build_ext": build_ext}, ) diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py index be34343d3..19e9b49cb 100755 --- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py +++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -8,7 +8,7 @@ import logging import onnxruntime as ort import soundfile as sf import torch -from inference import load_vocoder +from infer import load_vocoder from tokenizer import Tokenizer @@ -89,6 +89,7 @@ class OnnxHifiGANModel: self.model.get_inputs()[0].name: x.numpy(), }, )[0] + # audio: (batch_size, num_samples) return torch.from_numpy(audio) @@ -97,19 +98,24 @@ class OnnxModel: def __init__( self, filename: str, + tokens: str, ): session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 session_opts.intra_op_num_threads = 2 self.session_opts = session_opts - self.tokenizer = Tokenizer("./data/tokens.txt") + self.tokenizer = Tokenizer(tokens) self.model = ort.InferenceSession( filename, sess_options=self.session_opts, providers=["CPUExecutionProvider"], ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + metadata = self.model.get_modelmeta().custom_metadata_map + self.sample_rate = int(metadata["sample_rate"]) + for i in self.model.get_inputs(): print(i) @@ -126,7 +132,7 @@ class OnnxModel: print("x_lengths", x_lengths) print("x", x.shape) - temperature = torch.tensor([1.0], dtype=torch.float32) + noise_scale = torch.tensor([1.0], dtype=torch.float32) length_scale = torch.tensor([1.0], dtype=torch.float32) mel = self.model.run( @@ -134,10 +140,11 @@ class OnnxModel: { self.model.get_inputs()[0].name: x.numpy(), 
self.model.get_inputs()[1].name: x_lengths.numpy(), - self.model.get_inputs()[2].name: temperature.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), self.model.get_inputs()[3].name: length_scale.numpy(), }, )[0] + # mel: (batch_size, feat_dim, num_frames) return torch.from_numpy(mel) @@ -147,7 +154,7 @@ def main(): params = get_parser().parse_args() logging.info(vars(params)) - model = OnnxModel(params.acoustic_model) + model = OnnxModel(params.acoustic_model, params.tokens) vocoder = OnnxHifiGANModel(params.vocoder) text = params.input_text x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) @@ -164,15 +171,17 @@ def main(): print("audio", audio.shape) # (1, 1, num_samples) audio = audio.squeeze() + sample_rate = model.sample_rate + t = (end_t - start_t).total_seconds() t2 = (end_t2 - start_t2).total_seconds() - rtf_am = t * 22050 / audio.shape[-1] - rtf_vocoder = t2 * 22050 / audio.shape[-1] + rtf_am = t * sample_rate / audio.shape[-1] + rtf_vocoder = t2 * sample_rate / audio.shape[-1] print("RTF for acoustic model ", rtf_am) print("RTF for vocoder", rtf_vocoder) # skip denoiser - sf.write(params.output_wav, audio, 22050, "PCM_16") + sf.write(params.output_wav, audio, sample_rate, "PCM_16") logging.info(f"Saved to {params.output_wav}") diff --git a/egs/ljspeech/TTS/matcha/requirements.txt b/egs/ljspeech/TTS/matcha/requirements.txt index 5aadc8984..d7829c1e1 100644 --- a/egs/ljspeech/TTS/matcha/requirements.txt +++ b/egs/ljspeech/TTS/matcha/requirements.txt @@ -1,3 +1,4 @@ conformer==0.3.2 diffusers # developed using version ==0.25.0 librosa +einops \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 5e713fdfd..853042413 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -14,9 +14,9 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from lhotse.utils import fix_random_seed -from matcha.model import fix_len_compatibility -from matcha.models.matcha_tts import MatchaTTS -from matcha.tokenizer import Tokenizer +from model import fix_len_compatibility +from models.matcha_tts import MatchaTTS +from tokenizer import Tokenizer from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -150,7 +150,7 @@ def _get_data_params() -> AttributeDict: "n_spks": 1, "n_fft": 1024, "n_feats": 80, - "sample_rate": 22050, + "sampling_rate": 22050, "hop_length": 256, "win_length": 1024, "f_min": 0, @@ -445,11 +445,6 @@ def train_one_epoch( saved_bad_model = False - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - def save_bad_model(suffix: str = ""): save_checkpoint( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", @@ -493,9 +488,10 @@ def train_one_epoch( loss = sum(losses.values()) - optimizer.zero_grad() scaler.scale(loss).backward() scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() loss_info = MetricsTracker() loss_info["samples"] = batch_size diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py index 8e37fc030..1e637b766 100644 --- a/egs/ljspeech/TTS/matcha/tts_datamodule.py +++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py @@ -24,7 +24,7 @@ from pathlib import Path from typing import Any, Dict, Optional import torch -from compute_fbank_ljspeech import MyFbank, MyFbankConfig +from fbank import MatchaFbank, MatchaFbankConfig from 
lhotse import CutSet, load_manifest_lazy from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, @@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, PrecomputedFeatures, SimpleCutSampler, - SpecAugment, SpeechSynthesisDataset, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -177,7 +176,7 @@ class LJSpeechTtsDataModule: if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -189,7 +188,7 @@ class LJSpeechTtsDataModule: train = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) @@ -238,7 +237,7 @@ class LJSpeechTtsDataModule: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -250,7 +249,7 @@ class LJSpeechTtsDataModule: validate = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) else: @@ -282,7 +281,7 @@ class LJSpeechTtsDataModule: logging.info("About to create test dataset") if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -294,7 +293,7 @@ class LJSpeechTtsDataModule: test = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 6f16f8d47..ec5062933 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -25,26 +25,16 @@ log() { log "dl_dir: $dl_dir" if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: build monotonic_align lib" - if [ ! -d vits/monotonic_align/build ]; then - cd vits/monotonic_align - python3 setup.py build_ext --inplace - cd ../../ - else - log "monotonic_align lib for vits already built" - fi - - if [ ! -f ./matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then - pushd matcha/monotonic_align - python3 setup.py build - mv -v build/lib.*/matcha/monotonic_align/core.*.so . - rm -rf build - rm core.c - ls -lh - popd - else - log "monotonic_align lib for matcha-tts already built" - fi + log "Stage -1: build monotonic_align lib (used by vits and matcha recipes)" + for recipe in vits matcha; do + if [ ! -d $recipe/monotonic_align/build ]; then + cd $recipe/monotonic_align + python3 setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib for $recipe already built" + fi + done fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 7be76e315..cf1067dfc 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -234,7 +234,7 @@ def main(): logging.info(f"Number of parameters in discriminator: {num_param_d}") logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - # we need cut ids to display recognition results. 
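The train.py hunk above adds the missing scaler.update() call and clears gradients only after the optimizer step has been taken. A small self-contained sketch of that AMP ordering follows; the model, optimizer, and data are toy placeholders, not the recipe's actual training loop.

```python
# Correct GradScaler ordering: scale -> backward -> step -> update -> zero_grad.
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=torch.cuda.is_available())

for _ in range(3):
    x = torch.randn(8, 10)
    y = torch.randn(8, 1)
    with autocast(enabled=torch.cuda.is_available()):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()   # scale the loss before backward
    scaler.step(optimizer)          # unscales grads, then optimizer.step()
    scaler.update()                 # adjust the scale factor for the next step
    optimizer.zero_grad()           # clear grads only after the step
```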
+ # we need cut ids to organize tts results. args.return_cuts = True ljspeech = LJSpeechTtsDataModule(args) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/.gitignore b/egs/ljspeech/TTS/vits/monotonic_align/.gitignore new file mode 100644 index 000000000..3def4ae26 --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/.gitignore @@ -0,0 +1,3 @@ +build +core.c +*.so \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/test_model.py b/egs/ljspeech/TTS/vits/test_model.py index 1de10f012..4faaa96a5 100755 --- a/egs/ljspeech/TTS/vits/test_model.py +++ b/egs/ljspeech/TTS/vits/test_model.py @@ -18,7 +18,6 @@ from tokenizer import Tokenizer from train import get_model, get_params -from vits import VITS def test_model_type(model_type): diff --git a/egs/mdcc/ASR/zipformer/decode.py b/egs/mdcc/ASR/zipformer/decode.py index ce104baf7..d2ae26409 100755 --- a/egs/mdcc/ASR/zipformer/decode.py +++ b/egs/mdcc/ASR/zipformer/decode.py @@ -756,7 +756,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/mgb2/ASR/conformer_ctc/decode.py b/egs/mgb2/ASR/conformer_ctc/decode.py index f771d7f1e..26e470bd7 100755 --- a/egs/mgb2/ASR/conformer_ctc/decode.py +++ b/egs/mgb2/ASR/conformer_ctc/decode.py @@ -575,7 +575,7 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) ) assert HLG.requires_grad is False @@ -614,7 +614,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) G = k2.Fsa.from_dict(d) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: diff --git a/egs/mgb2/ASR/conformer_ctc/pretrained.py b/egs/mgb2/ASR/conformer_ctc/pretrained.py index 0ab2af527..8a3655bf6 100755 --- a/egs/mgb2/ASR/conformer_ctc/pretrained.py +++ b/egs/mgb2/ASR/conformer_ctc/pretrained.py @@ -275,7 +275,7 @@ def main(): use_feat_batchnorm=params.use_feat_batchnorm, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -347,7 +347,7 @@ def main(): "attention-decoder", ]: logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False)) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -358,7 +358,7 @@ def main(): "attention-decoder", ]: logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu", weights_only=False)) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py index 81a16f0ff..639099f8a 100755 --- 
a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py @@ -236,7 +236,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/multi_ja_en/ASR/README.md b/egs/multi_ja_en/ASR/README.md new file mode 100644 index 000000000..09964a4ab --- /dev/null +++ b/egs/multi_ja_en/ASR/README.md @@ -0,0 +1,17 @@ +# Introduction + +A bilingual Japanese-English ASR model that utilizes ReazonSpeech, developed by the developers of ReazonSpeech. + +**ReazonSpeech** is an open-source dataset that contains a diverse set of natural Japanese speech, collected from terrestrial television streams. It contains more than 35,000 hours of audio. + + +# Included Training Sets + +1. LibriSpeech (English) +2. ReazonSpeech (Japanese) + +|Datset| Number of hours| URL| +|---|---:|---| +|**TOTAL**|35,960|---| +|LibriSpeech|960|https://www.openslr.org/12/| +|ReazonSpeech (all) |35,000|https://huggingface.co/datasets/reazon-research/reazonspeech| diff --git a/egs/multi_ja_en/ASR/RESULTS.md b/egs/multi_ja_en/ASR/RESULTS.md new file mode 100644 index 000000000..0f6996013 --- /dev/null +++ b/egs/multi_ja_en/ASR/RESULTS.md @@ -0,0 +1,52 @@ +## Results + +### Zipformer + +#### Non-streaming + +The training command is: + +```shell +./zipformer/train.py \ + --bilingual 1 \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 600 +``` + +The decoding command is: + +```shell +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search +``` + +To export the model with onnx: + +```shell +./zipformer/export-onnx.py --tokens data/lang_bbpe_2000/tokens.txt --use-averaged-model 0 --epoch 35 --avg 1 --exp-dir zipformer/exp --num-encoder-layers "2,2,3,4,3,2" --downsampling-factor "1,2,4,8,4,2" --feedforward-dim "512,768,1024,1536,1024,768" --num-heads "4,4,4,8,4,4" --encoder-dim "192,256,384,512,384,256" --query-head-dim 32 --value-head-dim 12 --pos-head-dim 4 --pos-dim 48 --encoder-unmasked-dim "192,192,256,256,256,192" --cnn-module-kernel "31,31,15,15,15,31" --decoder-dim 512 --joiner-dim 512 --causal False --chunk-size "16,32,64,-1" --left-context-frames "64,128,256,-1" --fp16 True +``` +Word Error Rates (WERs) listed below: + +| Datasets | ReazonSpeech | ReazonSpeech | LibriSpeech | LibriSpeech | +|----------------------|--------------|---------------|--------------------|-------------------| +| Zipformer WER (%) | dev | test | test-clean | test-other | +| greedy_search | 5.9 | 4.07 | 3.46 | 8.35 | +| modified_beam_search | 4.87 | 3.61 | 3.28 | 8.07 | + + +Character Error Rates (CERs) for Japanese listed below: +| Decoding Method | In-Distribution CER | JSUT | CommonVoice | TEDx | +| :------------------: | :-----------------: | :--: | :---------: | :---: | +| greedy search | 12.56 | 6.93 | 9.75 | 9.67 | +| modified beam search | 11.59 | 6.97 | 9.55 | 9.51 | + +Pre-trained model can be found here: https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/main + diff --git a/egs/multi_ja_en/ASR/local/compute_fbank_reazonspeech.py b/egs/multi_ja_en/ASR/local/compute_fbank_reazonspeech.py new file mode 
100644 index 000000000..af7841406 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/compute_fbank_reazonspeech.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import os +from pathlib import Path +from typing import List, Tuple + +import torch + +# fmt: off +from lhotse import ( # See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527 + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + RecordingSet, + SupervisionSet, +) + +# fmt: on + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +RNG_SEED = 42 +concat_params = {"gap": 1.0, "maxlen": 10.0} + + +def make_cutset_blueprints( + manifest_dir: Path, +) -> List[Tuple[str, CutSet]]: + cut_sets = [] + + # Create test dataset + logging.info("Creating test cuts.") + cut_sets.append( + ( + "test", + CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / "reazonspeech_recordings_test.jsonl.gz" + ), + supervisions=SupervisionSet.from_file( + manifest_dir / "reazonspeech_supervisions_test.jsonl.gz" + ), + ), + ) + ) + + # Create dev dataset + logging.info("Creating dev cuts.") + cut_sets.append( + ( + "dev", + CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / "reazonspeech_recordings_dev.jsonl.gz" + ), + supervisions=SupervisionSet.from_file( + manifest_dir / "reazonspeech_supervisions_dev.jsonl.gz" + ), + ), + ) + ) + + # Create train dataset + logging.info("Creating train cuts.") + cut_sets.append( + ( + "train", + CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / "reazonspeech_recordings_train.jsonl.gz" + ), + supervisions=SupervisionSet.from_file( + manifest_dir / "reazonspeech_supervisions_train.jsonl.gz" + ), + ), + ) + ) + return cut_sets + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", "--manifest-dir", type=Path) + return parser.parse_args() + + +def main(): + args = get_args() + + extractor = Fbank(FbankConfig(num_mel_bins=80)) + num_jobs = min(16, os.cpu_count()) + + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + if (args.manifest_dir / ".reazonspeech-fbank.done").exists(): + logging.info( + "Previous fbank computed for ReazonSpeech found. " + f"Delete {args.manifest_dir / '.reazonspeech-fbank.done'} to allow recomputing fbank." 
+ ) + return + else: + cut_sets = make_cutset_blueprints(args.manifest_dir) + for part, cut_set in cut_sets: + logging.info(f"Processing {part}") + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + num_jobs=num_jobs, + storage_path=(args.manifest_dir / f"feats_{part}").as_posix(), + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz") + + logging.info("All fbank computed for ReazonSpeech.") + (args.manifest_dir / ".reazonspeech-fbank.done").touch() + + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/local/display_manifest_statistics.py b/egs/multi_ja_en/ASR/local/display_manifest_statistics.py new file mode 100644 index 000000000..ace1dd73f --- /dev/null +++ b/egs/multi_ja_en/ASR/local/display_manifest_statistics.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2022 The University of Electro-Communications (author: Teo Wen Shen) # noqa +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path + +from lhotse import CutSet, load_manifest + +ARGPARSE_DESCRIPTION = """ +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in +pruned_transducer_stateless5/train.py for usage. +""" + + +def get_parser(): + parser = argparse.ArgumentParser( + description=ARGPARSE_DESCRIPTION, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") + + return parser.parse_args() + + +def main(): + args = get_parser() + + for part in ["train", "dev"]: + path = args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz" + cuts: CutSet = load_manifest(path) + + print("\n---------------------------------\n") + print(path.name + ":") + cuts.describe() + + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/local/prepare_char.py b/egs/multi_ja_en/ASR/local/prepare_char.py new file mode 120000 index 000000000..42743b544 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/prepare_char.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/local/prepare_for_bpe_model.py b/egs/multi_ja_en/ASR/local/prepare_for_bpe_model.py new file mode 100755 index 000000000..27832ad1b --- /dev/null +++ b/egs/multi_ja_en/ASR/local/prepare_for_bpe_model.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
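compute_fbank_reazonspeech.py above configures an 80-bin Lhotse Fbank extractor and stores the features with LilcomChunkyWriter. As a quick, hedged illustration of what that extractor produces, here is the same configuration applied to one second of synthetic audio; the real script reads ReazonSpeech manifests instead, and no file paths are needed for this sketch.

```python
# Minimal sketch of the Lhotse Fbank extractor used above, on synthetic audio.
import numpy as np
from lhotse import Fbank, FbankConfig

extractor = Fbank(FbankConfig(num_mel_bins=80))
samples = (np.random.randn(16000) * 0.01).astype(np.float32)  # 1 s at 16 kHz
feats = extractor.extract(samples, sampling_rate=16000)
print(feats.shape)  # roughly (100, 80): ~10 ms frame shift, 80 mel bins
```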
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script tokenizes the training transcript by CJK characters +# and saves the result to transcript_chars.txt, which is used +# to train the BPE model later. + +import argparse +import re +from pathlib import Path + +from tqdm.auto import tqdm + +from icefall.utils import tokenize_by_ja_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Output directory. + The generated transcript_chars.txt is saved to this directory. + """, + ) + + parser.add_argument( + "--text", + type=str, + help="Training transcript.", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + text = Path(args.text) + + assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!" + + transcript_path = lang_dir / "transcript_chars.txt" + + with open(text, "r", encoding="utf-8") as fin: + with open(transcript_path, "w+", encoding="utf-8") as fout: + for line in tqdm(fin): + fout.write(tokenize_by_ja_char(line) + "\n") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/local/prepare_lang.py b/egs/multi_ja_en/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py b/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py new file mode 100755 index 000000000..6134710ad --- /dev/null +++ b/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
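prepare_for_bpe_model.py above tokenizes the training transcript by CJK character before byte-level BPE training. A short sketch of that preprocessing chain is shown below, assuming icefall is importable; the exact byte-encoded output string is an icefall implementation detail and is not asserted here.

```python
# Assumes icefall is on PYTHONPATH; output strings are printed, not asserted.
from icefall.byte_utils import byte_encode
from icefall.utils import tokenize_by_ja_char

text = "今日はとても良い天気です"
chars = tokenize_by_ja_char(text)   # CJK characters separated by spaces
print(chars)
print(byte_encode(chars))           # byte-level form later fed to sentencepiece
```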
+ + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bbpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +import re +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.byte_utils import byte_encode +from icefall.utils import str2bool, tokenize_by_ja_char + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_lexicon( + model_file: str, words: List[str], oov: str +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + oov: + The out of vocabulary word in lexicon. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + # Convert word to word piece IDs instead of word piece strings + # to avoid OOV tokens. + encode_words = [byte_encode(tokenize_by_ja_char(w)) for w in words] + words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int) + + # Now convert word piece IDs back to word piece strings. 
+ words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) + + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bbpe.model" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = ["", "!SIL", "", args.oov, "#0", "", ""] + + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/local/prepare_lang_char.py b/egs/multi_ja_en/ASR/local/prepare_lang_char.py new file mode 100644 index 000000000..19c5f4a31 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/prepare_lang_char.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "train_cut", metavar="train-cut", type=Path, help="Path to the train cut" + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default=Path("data/lang_char"), + help=( + "Name of lang dir. " + "If not set, this will default to lang_char_{trans-mode}" + ), + ) + + return parser.parse_args() + + +def main(): + args = get_args() + logging.basicConfig( + format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), + level=logging.INFO, + ) + + sysdef_string = set(["", "", "", " "]) + + token_set = set() + logging.info(f"Creating vocabulary from {args.train_cut}.") + train_cut: CutSet = CutSet.from_file(args.train_cut) + for cut in train_cut: + for sup in cut.supervisions: + token_set.update(sup.text) + + token_set = [""] + sorted(token_set - sysdef_string) + ["", ""] + args.lang_dir.mkdir(parents=True, exist_ok=True) + (args.lang_dir / "tokens.txt").write_text( + "\n".join(f"{t}\t{i}" for i, t in enumerate(token_set)) + ) + + (args.lang_dir / "lang_type").write_text("char") + logging.info("Done.") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/local/prepare_words.py b/egs/multi_ja_en/ASR/local/prepare_words.py new file mode 120000 index 000000000..ef2b4eaf3 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/prepare_words.py @@ -0,0 +1 @@ +../../../aishell2/ASR/local/prepare_words.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/local/text2segments.py b/egs/multi_ja_en/ASR/local/text2segments.py new file mode 100644 index 000000000..e0f3a15c4 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/text2segments.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) +# 2022 Xiaomi Corp. (authors: Weiji Zhuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
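prepare_lang_char.py above derives its token list directly from the characters seen in the training cuts. The stand-alone sketch below mirrors that logic on a couple of literal strings; the special-symbol names (<blk>, <unk>, <sos/eos>) follow the usual icefall convention and are an assumption here rather than being quoted from the script.

```python
# Character-level token list construction; special-symbol names are assumed.
sysdef = {"<blk>", "<unk>", "<sos/eos>", " "}
texts = ["今日はいい天気", "hello world"]

token_set = set()
for t in texts:
    token_set.update(t)          # every character becomes a candidate token

tokens = ["<blk>"] + sorted(token_set - sysdef) + ["<unk>", "<sos/eos>"]
print("\n".join(f"{t}\t{i}" for i, t in enumerate(tokens)))
```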
+ + +""" +This script takes as input "text", which refers to the transcript file: + - text +and generates the output file with word segmentation implemented using MeCab: + - text_words_segmentation +""" + +import argparse +from multiprocessing import Pool + +import MeCab +from tqdm import tqdm + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Japanese Word Segmentation for text", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--num-process", + "-n", + default=20, + type=int, + help="the number of processes", + ) + parser.add_argument( + "--input-file", + "-i", + default="data/lang_char/text", + type=str, + help="the input text file", + ) + parser.add_argument( + "--output-file", + "-o", + default="data/lang_char/text_words_segmentation", + type=str, + help="the text implemented with word segmentation using MeCab", + ) + + return parser + + +def cut(lines): + if lines is not None: + mecab = MeCab.Tagger("-Owakati") # Use '-Owakati' option for word segmentation + segmented_line = mecab.parse(lines).strip() + return segmented_line.split() # Return as a list of words + else: + return None + + +def main(): + parser = get_parser() + args = parser.parse_args() + + num_process = args.num_process + input_file = args.input_file + output_file = args.output_file + + with open(input_file, "r", encoding="utf-8") as fr: + lines = fr.readlines() + + with Pool(processes=num_process) as p: + new_lines = list(tqdm(p.imap(cut, lines), total=len(lines))) + + with open(output_file, "w", encoding="utf-8") as fw: + for line in new_lines: + fw.write(" ".join(line) + "\n") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/local/text2token.py b/egs/multi_ja_en/ASR/local/text2token.py new file mode 100755 index 000000000..ce64847c9 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/text2token.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# Copyright 2017 Johns Hopkins University (authors: Shinji Watanabe) +# 2022 Xiaomi Corp. (authors: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import argparse +import codecs +import re +import sys +from typing import List + +from romkan import to_roma # Replace with python-romkan v0.2.1 + +is_python2 = sys.version_info[0] == 2 + + +def exist_or_not(i, match_pos): + start_pos = None + end_pos = None + for pos in match_pos: + if pos[0] <= i < pos[1]: + start_pos = pos[0] + end_pos = pos[1] + break + + return start_pos, end_pos + + +def get_parser(): + parser = argparse.ArgumentParser( + description="convert raw text to tokenized text", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--nchar", + "-n", + default=1, + type=int, + help="number of characters to split, i.e., \ + aabb -> a a b b with -n 1 and aa bb with -n 2", + ) + parser.add_argument( + "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" + ) + parser.add_argument("--space", default="", type=str, help="space symbol") + parser.add_argument( + "--non-lang-syms", + "-l", + default=None, + type=str, + help="list of non-linguistic symbols, e.g., etc.", + ) + parser.add_argument("text", type=str, default=False, nargs="?", help="input text") + parser.add_argument( + "--trans_type", + "-t", + type=str, + default="char", + choices=["char", "romaji"], + help="Transcript type. char/romaji", + ) + return parser + + +def token2id( + texts, token_table, token_type: str = "romaji", oov: str = "" +) -> List[List[int]]: + """Convert token to id. + Args: + texts: + The input texts, it refers to the Japanese text here. + token_table: + The token table is built based on "data/lang_xxx/token.txt" + token_type: + The type of token, such as "romaji". + oov: + Out of vocabulary token. When a word(token) in the transcript + does not exist in the token list, it is replaced with `oov`. + + Returns: + The list of ids for the input texts. 
+ """ + if texts is None: + raise ValueError("texts can't be None!") + else: + oov_id = token_table[oov] + ids: List[List[int]] = [] + for text in texts: + chars_list = list(str(text)) + if token_type == "romaji": + text = [to_roma(c) for c in chars_list] + sub_ids = [ + token_table[txt] if txt in token_table else oov_id for txt in text + ] + ids.append(sub_ids) + return ids + + +def main(): + parser = get_parser() + args = parser.parse_args() + + rs = [] + if args.non_lang_syms is not None: + with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f: + nls = [x.rstrip() for x in f.readlines()] + rs = [re.compile(re.escape(x)) for x in nls] + + if args.text: + f = codecs.open(args.text, encoding="utf-8") + else: + f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) + + sys.stdout = codecs.getwriter("utf-8")( + sys.stdout if is_python2 else sys.stdout.buffer + ) + line = f.readline() + n = args.nchar + while line: + x = line.split() + print(" ".join(x[: args.skip_ncols]), end=" ") + a = " ".join(x[args.skip_ncols :]) # noqa E203 + + # get all matched positions + match_pos = [] + for r in rs: + i = 0 + while i >= 0: + m = r.search(a, i) + if m: + match_pos.append([m.start(), m.end()]) + i = m.end() + else: + break + if len(match_pos) > 0: + chars = [] + i = 0 + while i < len(a): + start_pos, end_pos = exist_or_not(i, match_pos) + if start_pos is not None: + chars.append(a[start_pos:end_pos]) + i = end_pos + else: + chars.append(a[i]) + i += 1 + a = chars + + if args.trans_type == "romaji": + a = [to_roma(c) for c in list(str(a))] + + a = [a[j : j + n] for j in range(0, len(a), n)] # noqa E203 + + a_flat = [] + for z in a: + a_flat.append("".join(z)) + + a_chars = "".join(a_flat) + print(a_chars) + line = f.readline() + + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/local/train_bbpe_model.py b/egs/multi_ja_en/ASR/local/train_bbpe_model.py new file mode 100755 index 000000000..d104f2717 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/train_bbpe_model.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import re +import shutil +import tempfile +from pathlib import Path + +import sentencepiece as spm + +from icefall import byte_encode +from icefall.utils import tokenize_by_ja_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. 
+        """,
+    )
+
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="Training transcript.",
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
+    )
+
+    return parser.parse_args()
+
+
+def _convert_to_bchar(in_path: str, out_path: str):
+    with open(out_path, "w") as f:
+        for line in open(in_path, "r").readlines():
+            f.write(byte_encode(tokenize_by_ja_char(line)) + "\n")
+
+
+def main():
+    args = get_args()
+    vocab_size = args.vocab_size
+    lang_dir = Path(args.lang_dir)
+
+    model_type = "unigram"
+
+    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+    model_file = Path(model_prefix + ".model")
+    if model_file.is_file():
+        print(f"{model_file} exists - skipping")
+        return
+
+    character_coverage = 1.0
+    input_sentence_size = 100000000
+
+    user_defined_symbols = ["<blk>", "<sos/eos>"]
+    unk_id = len(user_defined_symbols)
+    # Note: unk_id is fixed to 2.
+    # If you change it, you should also change other
+    # places that are using it.
+
+    temp = tempfile.NamedTemporaryFile()
+    train_text = temp.name
+
+    _convert_to_bchar(args.transcript, train_text)
+
+    spm.SentencePieceTrainer.train(
+        input=train_text,
+        vocab_size=vocab_size,
+        model_type=model_type,
+        model_prefix=model_prefix,
+        input_sentence_size=input_sentence_size,
+        character_coverage=character_coverage,
+        user_defined_symbols=user_defined_symbols,
+        unk_id=unk_id,
+        bos_id=-1,
+        eos_id=-1,
+    )
+
+    shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
new file mode 100644
index 000000000..be18e65c1
--- /dev/null
+++ b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
@@ -0,0 +1,355 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SimpleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class ReazonSpeechAsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+ It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/dev/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=False, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + def train_dataloaders( + self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + + transforms = [] + input_transforms = [] + + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/multi_ja_en/ASR/local/utils/tokenizer.py b/egs/multi_ja_en/ASR/local/utils/tokenizer.py new file mode 100644 index 000000000..ba71cff89 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/utils/tokenizer.py @@ -0,0 +1,252 @@ +import argparse +from pathlib import Path +from typing import Callable, List, Union + +import sentencepiece as spm +from k2 import SymbolTable + + +class Tokenizer: + text2word: Callable[[str], List[str]] + + @staticmethod + def 
add_arguments(parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(title="Lang related options")
+        group.add_argument("--lang", type=Path, help="Path to lang directory.")
+
+        group.add_argument(
+            "--lang-type",
+            type=str,
+            default=None,
+            help=(
+                "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exist. "
+                "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor"
+            ),
+        )
+
+    @staticmethod
+    def Load(lang_dir: Path, lang_type="", oov="<unk>"):
+
+        if not lang_type:
+            assert (lang_dir / "lang_type").exists(), "lang_type not specified."
+            lang_type = (lang_dir / "lang_type").read_text().strip()
+
+        tokenizer = None
+
+        if lang_type == "bpe":
+            assert (
+                lang_dir / "bpe.model"
+            ).exists(), f"No BPE .model could be found in {lang_dir}."
+            tokenizer = spm.SentencePieceProcessor()
+            tokenizer.Load(str(lang_dir / "bpe.model"))
+        elif lang_type == "char":
+            tokenizer = CharTokenizer(lang_dir, oov=oov)
+        else:
+            raise NotImplementedError(f"{lang_type} not supported at the moment.")
+
+        return tokenizer
+
+    load = Load
+
+    def PieceToId(self, piece: str) -> int:
+        raise NotImplementedError(
+            "You need to implement this function in the child class."
+        )
+
+    piece_to_id = PieceToId
+
+    def IdToPiece(self, id: int) -> str:
+        raise NotImplementedError(
+            "You need to implement this function in the child class."
+        )
+
+    id_to_piece = IdToPiece
+
+    def GetPieceSize(self) -> int:
+        raise NotImplementedError(
+            "You need to implement this function in the child class."
+        )
+
+    get_piece_size = GetPieceSize
+
+    def __len__(self) -> int:
+        return self.get_piece_size()
+
+    def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
+        raise NotImplementedError(
+            "You need to implement this function in the child class."
+        )
+
+    def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
+        raise NotImplementedError(
+            "You need to implement this function in the child class."
+        )
+
+    def EncodeAsIds(self, input: str) -> List[int]:
+        return self.EncodeAsIdsBatch([input])[0]
+
+    def EncodeAsPieces(self, input: str) -> List[str]:
+        return self.EncodeAsPiecesBatch([input])[0]
+
+    def Encode(
+        self, input: Union[str, List[str]], out_type=int
+    ) -> Union[List, List[List]]:
+        if not input:
+            return []
+
+        if isinstance(input, list):
+            if out_type is int:
+                return self.EncodeAsIdsBatch(input)
+            if out_type is str:
+                return self.EncodeAsPiecesBatch(input)
+
+        if out_type is int:
+            return self.EncodeAsIds(input)
+        if out_type is str:
+            return self.EncodeAsPieces(input)
+
+    encode = Encode
+
+    def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
+        raise NotImplementedError(
+            "You need to implement this function in the child class."
+        )
+
+    def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
+        raise NotImplementedError(
+            "You need to implement this function in the child class."
+        )
+
+    def DecodeIds(self, input: List[int]) -> str:
+        return self.DecodeIdsBatch([input])[0]
+
+    def DecodePieces(self, input: List[str]) -> str:
+        return self.DecodePiecesBatch([input])[0]
+
+    def Decode(
+        self,
+        input: Union[int, List[int], List[str], List[List[int]], List[List[str]]],
+    ) -> Union[List[str], str]:
+
+        if not input:
+            return ""
+
+        if isinstance(input, int):
+            return self.id_to_piece(input)
+        elif isinstance(input, str):
+            raise TypeError(
+                "Unlike spm.SentencePieceProcessor, cannot decode from type str."
+            )
+
+        if isinstance(input[0], list):
+            if not input[0] or isinstance(input[0][0], int):
+                return self.DecodeIdsBatch(input)
+
+            if isinstance(input[0][0], str):
+                return self.DecodePiecesBatch(input)
+
+        if isinstance(input[0], int):
+            return self.DecodeIds(input)
+        if isinstance(input[0], str):
+            return self.DecodePieces(input)
+
+        raise RuntimeError("Unknown input type")
+
+    decode = Decode
+
+    def SplitBatch(self, input: List[str]) -> List[List[str]]:
+        raise NotImplementedError(
+            "You need to implement this function in the child class."
+        )
+
+    def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]:
+        if isinstance(input, list):
+            return self.SplitBatch(input)
+        elif isinstance(input, str):
+            return self.SplitBatch([input])[0]
+        raise RuntimeError("Unknown input type")
+
+    split = Split
+
+
+class CharTokenizer(Tokenizer):
+    def __init__(self, lang_dir: Path, oov="<unk>", sep=""):
+        assert (
+            lang_dir / "tokens.txt"
+        ).exists(), f"tokens.txt could not be found in {lang_dir}."
+        token_table = SymbolTable.from_file(lang_dir / "tokens.txt")
+        assert (
+            "#0" not in token_table
+        ), "This tokenizer does not support disambig symbols."
+        self._id2sym = token_table._id2sym
+        self._sym2id = token_table._sym2id
+        self.oov = oov
+        self.oov_id = self._sym2id[oov]
+        self.sep = sep
+        if self.sep:
+            self.text2word = lambda x: x.split(self.sep)
+        else:
+            self.text2word = lambda x: list(x.replace(" ", ""))
+
+    def piece_to_id(self, piece: str) -> int:
+        try:
+            return self._sym2id[piece]
+        except KeyError:
+            return self.oov_id
+
+    def id_to_piece(self, id: int) -> str:
+        return self._id2sym[id]
+
+    def get_piece_size(self) -> int:
+        return len(self._sym2id)
+
+    def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
+        return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input]
+
+    def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
+        return [
+            [i if i in self._sym2id else self.oov for i in self.text2word(text)]
+            for text in input
+        ]
+
+    def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
+        return [self.sep.join(self.id_to_piece(i) for i in text) for text in input]
+
+    def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
+        return [self.sep.join(text) for text in input]
+
+    def SplitBatch(self, input: List[str]) -> List[List[str]]:
+        return [self.text2word(text) for text in input]
+
+
+def test_CharTokenizer():
+    test_single_string = "こんにちは"
+    test_multiple_string = [
+        "今日はいい天気ですよね",
+        "諏訪湖は綺麗でしょう",
+        "这在词表外",
+        "分かち 書き に し た 文章 です",
+        "",
+    ]
+    test_empty_string = ""
+    sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>")
+    splitter = sp.split
+    print(sp.encode(test_single_string, out_type=str))
+    print(sp.encode(test_single_string, out_type=int))
+    print(sp.encode(test_multiple_string, out_type=str))
+    print(sp.encode(test_multiple_string, out_type=int))
+    print(sp.encode(test_empty_string, out_type=str))
+    print(sp.encode(test_empty_string, out_type=int))
+    print(sp.decode(sp.encode(test_single_string, out_type=str)))
+    print(sp.decode(sp.encode(test_single_string, out_type=int)))
+    print(sp.decode(sp.encode(test_multiple_string, out_type=str)))
+    print(sp.decode(sp.encode(test_multiple_string, out_type=int)))
+    print(sp.decode(sp.encode(test_empty_string, out_type=str)))
+    print(sp.decode(sp.encode(test_empty_string, out_type=int)))
+    print(splitter(test_single_string))
+    print(splitter(test_multiple_string))
+    print(splitter(test_empty_string))
+
+
+if __name__ == "__main__":
+    test_CharTokenizer()
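+
+# Illustrative usage sketch (assumes a lang directory such as data/lang_char
+# containing tokens.txt; names here are examples, not part of this module):
+#   sp = Tokenizer.load(Path("data/lang_char"), "char")
+#   ids = sp.encode("諏訪湖は綺麗でしょう", out_type=int)
+#   text = sp.decode(ids)
+# With the default sep="", text2word splits a sentence into characters after
+# removing spaces; with a non-empty sep it splits on that separator instead.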
diff --git a/egs/multi_ja_en/ASR/local/validate_bpe_lexicon.py b/egs/multi_ja_en/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/local/validate_manifest.py b/egs/multi_ja_en/ASR/local/validate_manifest.py new file mode 100644 index 000000000..7f67c64b6 --- /dev/null +++ b/egs/multi_ja_en/ASR/local/validate_manifest.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut +- Supervision time bounds are within cut time bounds + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/fbank/librispeech_cuts_train-clean-100.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def validate_one_supervision_per_cut(c: Cut): + if len(c.supervisions) != 1: + raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") + + +def validate_supervision_and_cut_time_bounds(c: Cut): + s = c.supervisions[0] + + # Removed because when the cuts were trimmed from supervisions, + # the start time of the supervision can be lesser than cut start time. 
+ # https://github.com/lhotse-speech/lhotse/issues/813 + # if s.start < c.start: + # raise ValueError( + # f"{c.id}: Supervision start time {s.start} is less " + # f"than cut start time {c.start}" + # ) + + if s.end > c.end: + raise ValueError( + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" + ) + + +def main(): + args = get_args() + + manifest = Path(args.manifest) + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest(manifest) + assert isinstance(cut_set, CutSet) + + for c in cut_set: + validate_one_supervision_per_cut(c) + validate_supervision_and_cut_time_bounds(c) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/multi_ja_en/ASR/prepare.sh b/egs/multi_ja_en/ASR/prepare.sh new file mode 100755 index 000000000..7a6a63418 --- /dev/null +++ b/egs/multi_ja_en/ASR/prepare.sh @@ -0,0 +1,185 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +vocab_sizes=( + 2000 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +log "Dataset: musan" +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Soft link fbank of musan" + mkdir -p data/fbank + if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) . + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) . + cd ../.. + else + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4" + exit 1 + fi +fi + +log "Dataset: LibriSpeech" +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 1: Soft link fbank of LibriSpeech" + mkdir -p data/fbank + if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts*) . + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats*) . + cd ../.. + else + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 1 --stop-stage 1 and ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi +fi + +log "Dataset: ReazonSpeech" +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 2: Soft link fbank of ReazonSpeech" + mkdir -p data/fbank + if [ -e ../../reazonspeech/ASR/data/manifests/.reazonspeech.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../reazonspeech/ASR/data/manifests/reazonspeech_cuts*) . + cd .. + mkdir -p manifests + cd manifests + ln -svf $(realpath ../../../../reazonspeech/ASR/data/manifests/feats_*) . + cd ../.. + else + log "Abort! 
Please run ../../reazonspeech/ASR/prepare.sh --stage 0 --stop-stage 2" + exit 1 + fi +fi + +# New Stage 3: Prepare char based lang for ReazonSpeech +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + lang_char_dir=data/lang_char + log "Stage 3: Prepare char based lang for ReazonSpeech" + mkdir -p $lang_char_dir + + # Prepare text + if [ ! -f $lang_char_dir/text ]; then + gunzip -c ../../reazonspeech/ASR/data/manifests/reazonspeech_supervisions_train.jsonl.gz \ + | jq '.text' | sed 's/"//g' \ + | ./local/text2token.py -t "char" > $lang_char_dir/text + fi + + # jp word segmentation for text + if [ ! -f $lang_char_dir/text_words_segmentation ]; then + python3 ./local/text2segments.py \ + --input-file $lang_char_dir/text \ + --output-file $lang_char_dir/text_words_segmentation + fi + + cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt + + if [ ! -f $lang_char_dir/words.txt ]; then + python3 ./local/prepare_words.py \ + --input-file $lang_char_dir/words_no_ids.txt \ + --output-file $lang_char_dir/words.txt + fi + + if [ ! -f $lang_char_dir/L_disambig.pt ]; then + python3 ./local/prepare_char.py --lang-dir data/lang_char + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare Byte BPE based lang" + mkdir -p data/fbank + if [ ! -d ../../reazonspeech/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then + log "Abort! Please run ../../reazonspeech/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi + + if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 5 --stop-stage 5" + exit 1 + fi + + cd data/ + # if [ ! -d ./lang_char ]; then + # ln -svf $(realpath ../../../reazonspeech/ASR/data/lang_char) . + # fi + if [ ! -d ./lang_bpe_500 ]; then + ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) . + fi + cd ../ + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + mkdir -p $lang_dir + + cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \ + > $lang_dir/text + + if [ ! -f $lang_dir/transcript_chars.txt ]; then + ./local/prepare_for_bpe_model.py \ + --lang-dir ./$lang_dir \ + --text $lang_dir/text + fi + + if [ ! -f $lang_dir/text_words_segmentation ]; then + python3 ./local/text2segments.py \ + --input-file ./data/lang_char/text \ + --output-file $lang_dir/text_words_segmentation + + cat ./data/lang_bpe_500/transcript_words.txt \ + >> $lang_dir/text_words_segmentation + fi + + cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' | uniq > $lang_dir/words_no_ids.txt + + if [ ! -f $lang_dir/words.txt ]; then + python3 ./local/prepare_words.py \ + --input-file $lang_dir/words_no_ids.txt \ + --output-file $lang_dir/words.txt + fi + + if [ ! -f $lang_dir/bbpe.model ]; then + ./local/train_bbpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/text + fi + + if [ ! 
-f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bbpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ln -svf $(realpath ../../multi_zh_en/ASR/local/validate_bpe_lexicon.py) local/ + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bbpe.model + fi + done +fi + +log "prepare.sh: PREPARATION DONE" diff --git a/egs/multi_ja_en/ASR/shared b/egs/multi_ja_en/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/multi_ja_en/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/asr_datamodule.py b/egs/multi_ja_en/ASR/zipformer/asr_datamodule.py new file mode 120000 index 000000000..a48591198 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/asr_datamodule.py @@ -0,0 +1 @@ +../local/utils/asr_datamodule.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/beam_search.py b/egs/multi_ja_en/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8e2c0a65c --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/ctc_decode.py b/egs/multi_ja_en/ASR/zipformer/ctc_decode.py new file mode 120000 index 000000000..faa8bd562 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/ctc_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/ctc_decode.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/decode.py b/egs/multi_ja_en/ASR/zipformer/decode.py new file mode 100755 index 000000000..9acccfcf7 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/decode.py @@ -0,0 +1,792 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + +import argparse +import logging +import math +import re +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import ReazonSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from train import add_model_arguments, get_model, get_params + +from icefall import byte_encode, smart_byte_decode +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + tokenize_by_ja_char, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_2000/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bbpe_2000", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. 
For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(byte_encode(tokenize_by_ja_char(supervisions["text"]))), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, 
+ ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
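+    # num_batches stays "?" when the dataloader has no __len__ (e.g. when the
+    # sampler iterates over a lazily opened CutSet); it is only used in the
+    # progress log messages below.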
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [tokenize_by_ja_char(str(text)).split() for text in texts] + # print(texts) + # exit() + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + ReazonSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+        params.suffix += f"-chunk-{params.chunk_size}"
+        params.suffix += f"-left-context-{params.left_context_frames}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) 
to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device, weights_only=False) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + data_module = ReazonSpeechAsrDataModule(args) + multi_dataset = MultiDataset(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_dl = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dl): + logging.info(f"Start decoding test set: {test_set}") + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/zipformer/decode_stream.py b/egs/multi_ja_en/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/decoder.py b/egs/multi_ja_en/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py b/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py new file mode 100755 index 000000000..072679cfc --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py @@ -0,0 +1,1261 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import math +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import ReazonSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer_for_ncnn_export_only import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LOG_EPS = math.log(1e-10) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. 
" + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=0, + help=""" + Number of frames to pad at the end. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 1000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + is_pnnx=True, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] 
= None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: Tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. 
+ train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: # noqa + logging.error(e, exc_info=True) + display_and_save_batch(batch, params=params, sp=sp) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
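+ # A sketch of the schedule implemented just below (no new logic is added
+ # here): every 100 batches the scale is doubled while it is still below
+ # 1.0, and every 400 batches it is doubled while it is below 8.0, so a
+ # scale that collapsed after inf/nan gradients recovers within a few
+ # hundred batches instead of waiting for GradScaler's own growth_interval.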
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + log_mode = logging.info + log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") + log_mode( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, master_port=params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 0.3 or c.duration > 30.0: + logging.debug( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.info( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. 
" + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + train_cuts = reazonspeech_corpus.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = reazonspeech_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = reazonspeech_corpus.valid_cuts() + valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: Tokenizer, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: Tokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + raise RuntimeError("Please don't use this file directly!") + parser = get_parser() + ReazonSpeechAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/zipformer/encoder_interface.py b/egs/multi_ja_en/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/export-onnx.py b/egs/multi_ja_en/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/export.py b/egs/multi_ja_en/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/generate_averaged_model.py b/egs/multi_ja_en/ASR/zipformer/generate_averaged_model.py new file mode 120000 index 000000000..5a015ee6c --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/joiner.py b/egs/multi_ja_en/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/model.py b/egs/multi_ja_en/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/multi_dataset.py b/egs/multi_ja_en/ASR/zipformer/multi_dataset.py new file mode 100644 index 000000000..b0cdc1f6a --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/multi_dataset.py @@ -0,0 +1,143 @@ +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from 
typing import Dict + +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, args: argparse.Namespace): + """ + Args: + manifest_dir: + It is expected to contain the following files: + - reazonspeech_cuts_train.jsonl.gz + - librispeech_cuts_train-clean-100.jsonl.gz + - librispeech_cuts_train-clean-360.jsonl.gz + - librispeech_cuts_train-other-500.jsonl.gz + """ + self.fbank_dir = Path(args.manifest_dir) + + def train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + + logging.info("Loading Reazonspeech in lazy mode") + reazonspeech_cuts = load_manifest_lazy( + self.fbank_dir / "reazonspeech_cuts_train.jsonl.gz" + ) + + logging.info("Loading LibriSpeech in lazy mode") + train_clean_100_cuts = self.train_clean_100_cuts() + train_clean_360_cuts = self.train_clean_360_cuts() + train_other_500_cuts = self.train_other_500_cuts() + + return CutSet.mux( + reazonspeech_cuts, + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + weights=[ + len(reazonspeech_cuts), + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + ], + ) + + def dev_cuts(self) -> CutSet: + logging.info("About to get multidataset dev cuts") + + logging.info("Loading Reazonspeech DEV set in lazy mode") + reazonspeech_dev_cuts = load_manifest_lazy( + self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz" + ) + + logging.info("Loading LibriSpeech DEV set in lazy mode") + dev_clean_cuts = self.dev_clean_cuts() + dev_other_cuts = self.dev_other_cuts() + + return CutSet.mux( + reazonspeech_dev_cuts, + dev_clean_cuts, + dev_other_cuts, + weights=[ + len(reazonspeech_dev_cuts), + len(dev_clean_cuts), + len(dev_other_cuts), + ], + ) + + def test_cuts(self) -> Dict[str, CutSet]: + logging.info("About to get multidataset test cuts") + + logging.info("Loading Reazonspeech set in lazy mode") + reazonspeech_test_cuts = load_manifest_lazy( + self.fbank_dir / "reazonspeech_cuts_test.jsonl.gz" + ) + reazonspeech_dev_cuts = load_manifest_lazy( + self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz" + ) + + logging.info("Loading LibriSpeech set in lazy mode") + test_clean_cuts = self.test_clean_cuts() + test_other_cuts = self.test_other_cuts() + + test_cuts = { + "reazonspeech_test": reazonspeech_test_cuts, + "reazonspeech_dev": reazonspeech_dev_cuts, + "librispeech_test_clean": test_clean_cuts, + "librispeech_test_other": test_other_cuts, + } + + return test_cuts + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get 
test-clean cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/multi_ja_en/ASR/zipformer/my_profile.py b/egs/multi_ja_en/ASR/zipformer/my_profile.py new file mode 120000 index 000000000..3a90b2628 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/my_profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/onnx_decode.py b/egs/multi_ja_en/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/onnx_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/onnx_pretrained.py b/egs/multi_ja_en/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/optim.py b/egs/multi_ja_en/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/pretrained.py b/egs/multi_ja_en/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/scaling.py b/egs/multi_ja_en/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/scaling_converter.py b/egs/multi_ja_en/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/streaming_beam_search.py b/egs/multi_ja_en/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/streaming_decode.py b/egs/multi_ja_en/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..935f86de1 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/streaming_decode.py @@ -0,0 +1,935 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + +Monolingual: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp-large \ + --lang data/lang_char \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 + +Bilingual: +./zipformer/streaming_decode.py \ + --bilingual 1 \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp-large \ + --lang data/lang_char \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + +""" + +import argparse +import logging +import math +import os +import pdb +import subprocess as sp +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import ReazonSpeechAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from tokenizer import Tokenizer +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bilingual", + type=str2bool, + default=False, + help="Whether the model is bilingual or not. 1 = bilingual.", + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
+ "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). 
+ state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
+ """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=model.device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=model.device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + # finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
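+ # The stream keeps the cached encoder states, the remaining fbank frames
+ # and the partial hypothesis for this utterance; it is decoded chunk by
+ # chunk below until DecodeStream.done becomes True.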
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + + if not finished_streams: + print("No finished streams, breaking the loop") + break + + for i in sorted(finished_streams, reverse=True): + try: + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + except IndexError as e: + print(f"IndexError: {e}") + print(f"decode_streams length: {len(decode_streams)}") + print(f"finished_streams: {finished_streams}") + print(f"i: {i}") + continue + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + torch.cuda.synchronize() + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
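+ # Files written per test set (names taken from the f-strings in this
+ # function), all under <exp-dir>/streaming/<decoding-method>/:
+ # recogs-<test_set>-<key>-<suffix>.txt reference/hypothesis transcripts
+ # errs-<test_set>-<key>-<suffix>.txt per-word error statistics
+ # wer-summary-<test_set>-<key>-<suffix>.txt WER for each decoding setting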
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + ReazonSpeechAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + if not params.bilingual: + sp = Tokenizer.load(params.lang, params.lang_type) + else: + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + 
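+ # For example, with the defaults --epoch 28 --avg 15 this averages the
+ # state_dicts of epoch-14.pt .. epoch-28.pt (a plain parameter average,
+ # unlike the --use-averaged-model branch handled below).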
model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
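With --use-averaged-model, the weights instead come from average_checkpoints_with_averaged_model. Conceptually, each saved checkpoint carries a running average of the parameters since the start of training (the model_avg maintained every --average-period batches), so the average over the interval (start, end] can be recovered as a weighted difference of two running averages. The sketch below only illustrates that arithmetic; the counts n_start and n_end stand in for whatever counters the real helper reads from the checkpoints.

```python
import torch


def interval_average(
    avg_start: torch.Tensor, n_start: int, avg_end: torch.Tensor, n_end: int
) -> torch.Tensor:
    """Mean over updates (n_start, n_end], given two running averages."""
    return (avg_end * n_end - avg_start * n_start) / (n_end - n_start)


# Tiny numeric check with the sequence 1, 2, 3, 4:
avg_2 = torch.tensor((1 + 2) / 2)          # running average after 2 updates
avg_4 = torch.tensor((1 + 2 + 3 + 4) / 4)  # running average after 4 updates
print(interval_average(avg_2, 2, avg_4, 4))  # tensor(3.5000) == mean of (3, 4)
```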
+ args.return_cuts = True + reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + + if params.bilingual: + multi_dataset = MultiDataset(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets_cuts = multi_dataset.test_cuts() + test_sets = test_sets_cuts.keys() + test_cuts = [test_sets_cuts[k] for k in test_sets] + + valid_cuts = reazonspeech_corpus.valid_cuts() + test_cuts = reazonspeech_corpus.test_cuts() + + test_sets = ["valid", "test"] + test_cuts = [valid_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + logging.info(f"Decoding {test_set}") + if params.bilingual: + test_cut = test_cut.filter(remove_short_utt) + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/zipformer/subsampling.py b/egs/multi_ja_en/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/test_scaling.py b/egs/multi_ja_en/ASR/zipformer/test_scaling.py new file mode 120000 index 000000000..715798436 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/test_scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_scaling.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/test_subsampling.py b/egs/multi_ja_en/ASR/zipformer/test_subsampling.py new file mode 120000 index 000000000..bf0ee3d11 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/test_subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_subsampling.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/tokenizer.py b/egs/multi_ja_en/ASR/zipformer/tokenizer.py new file mode 120000 index 000000000..958c99e85 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/tokenizer.py @@ -0,0 +1 @@ +../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/multi_ja_en/ASR/zipformer/train.py b/egs/multi_ja_en/ASR/zipformer/train.py new file mode 100755 index 000000000..bfb037f50 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/train.py @@ -0,0 +1,1462 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
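The remove_short_utt filter in streaming_decode.py above drops cuts whose encoder output would be empty: the convolutional frontend turns T feature frames into roughly ((T - 7) // 2 + 1) // 2 frames after the final 2x downsampling, so very short cuts end up with T <= 0. A quick, purely illustrative check of that length formula:

```python
def encoder_out_len(num_frames: int) -> int:
    """Approximate encoder output length used by remove_short_utt."""
    return ((num_frames - 7) // 2 + 1) // 2


for t in (5, 8, 9, 100):
    print(t, "->", encoder_out_len(t))
# 5 -> 0, 8 -> 0, 9 -> 1, 100 -> 23; cuts with a non-positive
# output length are excluded from decoding.
```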
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --bilingual 1 \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 600 + +# For streaming model training: +./zipformer/train.py \ + --bilingual 1 \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 600 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + +import argparse +import copy +import logging +import os +import re +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import ReazonSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from multi_dataset import MultiDataset +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import byte_encode, diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + tokenize_by_ja_char, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bilingual", + type=str2bool, + default=False, + help="Whether the model is bilingual or not. 1 = bilingual.", + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + # changed - not used in monolingual streaming + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_2000/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.015, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. 
+ + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + 
vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +# fix implementation for sentencepiece_processor: spm.SentencePieceProcessor, stuff +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + sentencepiece_processor: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + if not params.bilingual: + y = tokenizer.encode(texts, out_type=int) + else: + y = sentencepiece_processor.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + losses = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + simple_loss, pruned_loss, ctc_loss = losses[:3] + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + sentencepiece_processor: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + sentencepiece_processor=sentencepiece_processor, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + tokenizer: Tokenizer, + sentencepiece_processor: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + sentencepiece_processor=sentencepiece_processor, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch( + batch, + params=params, + tokenizer=tokenizer, + sentencepiece_processor=sentencepiece_processor, + ) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + sentencepiece_processor=sentencepiece_processor, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + # Use lang_dir for further operations + # tokenizer = Tokenizer.load(args.lang, args.lang_type) + + # sentencepiece_processor = spm.SentencePieceProcessor() + # sentencepiece_processor.load(params.bpe_model) + tokenizer = None + sentencepiece_processor = None + + # is defined in local/prepare_lang_char.py + + if not params.bilingual: + tokenizer = Tokenizer.load(args.lang, args.lang_type) + params.blank_id = tokenizer.piece_to_id("") + params.vocab_size = tokenizer.get_piece_size() + else: + sentencepiece_processor = spm.SentencePieceProcessor() + sentencepiece_processor.load(params.bpe_model) + params.blank_id = sentencepiece_processor.piece_to_id("") + params.vocab_size = sentencepiece_processor.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + if params.bilingual: + multi_dataset = MultiDataset(args) + train_cuts = multi_dataset.train_cuts() + else: + train_cuts = reazonspeech_corpus.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + # if c.duration < 1.0 or c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + # return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_samples - 7) // 2 + 1) // 2 + if not params.bilingual: + tokens = tokenizer.encode(c.supervisions[0].text, out_type=str) + else: + tokens = sentencepiece_processor.encode( + c.supervisions[0].text, out_type=str + ) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_samples}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + def tokenize_and_encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = byte_encode(tokenize_by_ja_char(text)) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.bilingual: + train_cuts = train_cuts.map(tokenize_and_encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = reazonspeech_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + if params.bilingual: + valid_cuts = reazonspeech_corpus.valid_cuts() + else: + valid_cuts = multi_dataset.dev_cuts() + valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + tokenizer=tokenizer, + sentencepiece_processor=sentencepiece_processor, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + tokenizer=tokenizer, + sentencepiece_processor=sentencepiece_processor, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + 
batch: dict, + params: AttributeDict, + tokenizer: Tokenizer, + sentencepiece_processor: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + tokenizer: + The BPE Tokenizer model. + sentencepiece_processor: + The BPE SentencePieceProcessor model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + if params.bilingual: + y = sentencepiece_processor.encode(supervisions["text"], out_type=int) + else: + y = tokenizer.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + tokenizer: Tokenizer, + sentencepiece_processor: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + sentencepiece_processor=sentencepiece_processor, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch( + batch, + params=params, + tokenizer=tokenizer, + sentencepiece_processor=sentencepiece_processor, + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + ReazonSpeechAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/multi_ja_en/ASR/zipformer/zipformer.py b/egs/multi_ja_en/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/multi_ja_en/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py index 6f75dbfa4..2bbe28560 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py @@ -52,13 +52,19 @@ def get_parser(): default=80, help="""The number of mel bins for Fbank""", ) - parser.add_argument( "--whisper-fbank", type=str2bool, default=False, help="Use WhisperFbank instead of Fbank. Default: False.", ) + parser.add_argument( + "--speed-perturb", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + return parser @@ -104,6 +110,9 @@ def compute_fbank_kespeech_dev_test(args): keep_overlapping=False, min_duration=None ) + if args.speed_perturb: + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + logging.info("Computing features") cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py index c398411f6..fe7f87337 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py @@ -106,6 +106,14 @@ def get_parser(): default=False, help="Use WhisperFbank instead of Fbank. Default: False.", ) + + parser.add_argument( + "--speed-perturb", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. 
Default: False.", + ) + return parser @@ -158,6 +166,9 @@ def compute_fbank_kespeech_splits(args): keep_overlapping=False, min_duration=None ) + if args.speed_perturb: + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + logging.info("Computing features") cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, diff --git a/egs/multi_zh-hans/ASR/whisper/decode.py b/egs/multi_zh-hans/ASR/whisper/decode.py index f758f546c..5b9665c5a 100755 --- a/egs/multi_zh-hans/ASR/whisper/decode.py +++ b/egs/multi_zh-hans/ASR/whisper/decode.py @@ -90,10 +90,10 @@ def average_checkpoints( """ n = len(filenames) - if "model" in torch.load(filenames[0], map_location=device): - avg = torch.load(filenames[0], map_location=device)["model"] + if "model" in torch.load(filenames[0], map_location=device, weights_only=False): + avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"] else: - avg = torch.load(filenames[0], map_location=device) + avg = torch.load(filenames[0], map_location=device, weights_only=False) # Identify shared parameters. Two parameters are said to be shared # if they have the same data_ptr @@ -108,10 +108,10 @@ def average_checkpoints( uniqued_names = list(uniqued.values()) for i in range(1, n): - if "model" in torch.load(filenames[i], map_location=device): - state_dict = torch.load(filenames[i], map_location=device)["model"] + if "model" in torch.load(filenames[i], map_location=device, weights_only=False): + state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] else: - state_dict = torch.load(filenames[i], map_location=device) + state_dict = torch.load(filenames[i], map_location=device, weights_only=False) for k in uniqued_names: avg[k] += state_dict[k] @@ -484,7 +484,7 @@ def main(): start = params.epoch - params.avg assert start >= 1, start checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False ) if "model" not in checkpoint: # deepspeed converted checkpoint only contains model state_dict @@ -513,7 +513,7 @@ def main(): torch.save(model.state_dict(), filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False ) if "model" not in checkpoint: model.load_state_dict(checkpoint, strict=True) diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py index fe2d950c1..3ffaef212 100755 --- a/egs/multi_zh-hans/ASR/whisper/train.py +++ b/egs/multi_zh-hans/ASR/whisper/train.py @@ -809,7 +809,7 @@ def run(rank, world_size, args): del model.alignment_heads if params.pretrained_model_path: - checkpoint = torch.load(params.pretrained_model_path, map_location="cpu") + checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False) if "model" not in checkpoint: model.load_state_dict(checkpoint, strict=True) else: diff --git a/egs/multi_zh-hans/ASR/zipformer/decode.py b/egs/multi_zh-hans/ASR/zipformer/decode.py index a1d018cd2..e2f7bd678 100755 --- a/egs/multi_zh-hans/ASR/zipformer/decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/decode.py @@ -784,7 +784,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) 
decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/multi_zh-hans/ASR/zipformer/export_rknn_ctc_streaming.py b/egs/multi_zh-hans/ASR/zipformer/export_rknn_ctc_streaming.py new file mode 120000 index 000000000..761e399ba --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/export_rknn_ctc_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export_rknn_ctc_streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/export_rknn_transducer_streaming.py b/egs/multi_zh-hans/ASR/zipformer/export_rknn_transducer_streaming.py new file mode 120000 index 000000000..8be19ef3d --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/export_rknn_transducer_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export_rknn_transducer_streaming.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py b/egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py index 68111fad7..0164456b3 100755 --- a/egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py +++ b/egs/multi_zh-hans/ASR/zipformer/generate_averaged_model.py @@ -24,7 +24,7 @@ Usage: --exp-dir ./zipformer/exp It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15.pt")`. +You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`. (2) use the checkpoint exp_dir/checkpoint-iter.pt ./zipformer/generate_averaged_model.py \ @@ -33,7 +33,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`. --exp-dir ./zipformer/exp It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5.pt")`. +You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`. 
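Most of the remaining hunks simply add weights_only=False to torch.load: newer PyTorch releases default to weights_only=True, which refuses to unpickle the non-tensor objects (params dicts, sampler state, k2 FSA dicts) stored in these checkpoints. A short sketch of the pattern, with a hypothetical checkpoint path:

```python
import torch

# Hypothetical path; any icefall epoch-*.pt or checkpoint-*.pt behaves the same.
ckpt_path = "zipformer/exp/epoch-28.pt"

# These checkpoints hold full Python objects, not just tensors, so they need
# weights_only=False now that torch.load defaults to weights_only=True.
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
print(list(state_dict)[:5])
```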
""" diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py index c15db11f7..53be57fae 100755 --- a/egs/multi_zh-hans/ASR/zipformer/pretrained.py +++ b/egs/multi_zh-hans/ASR/zipformer/pretrained.py @@ -291,7 +291,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() @@ -328,9 +328,14 @@ def main(): logging.info(msg) def token_ids_to_words(token_ids: List[int]) -> str: - text = "" + byte_list = [] for i in token_ids: - text += token_table[i] + token = token_table[i] + if token.startswith("<0x") and token.endswith(">"): + byte_list.append(int(token[3:-1], 16)) + else: + byte_list += list(token.encode("utf-8")) + text = bytes(byte_list).decode("utf-8", errors='ignore') return text.replace("▁", " ").strip() if params.method == "fast_beam_search": diff --git a/egs/multi_zh-hans/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py b/egs/multi_zh-hans/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py new file mode 120000 index 000000000..6417f470f --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/decode.py b/egs/multi_zh_en/ASR/zipformer/decode.py index e21e8f052..b5b87af41 100755 --- a/egs/multi_zh_en/ASR/zipformer/decode.py +++ b/egs/multi_zh_en/ASR/zipformer/decode.py @@ -792,7 +792,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py index 68111fad7..0164456b3 100755 --- a/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py +++ b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py @@ -24,7 +24,7 @@ Usage: --exp-dir ./zipformer/exp It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. -You can later load it by `torch.load("epoch-28-avg-15.pt")`. +You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`. (2) use the checkpoint exp_dir/checkpoint-iter.pt ./zipformer/generate_averaged_model.py \ @@ -33,7 +33,7 @@ You can later load it by `torch.load("epoch-28-avg-15.pt")`. --exp-dir ./zipformer/exp It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. -You can later load it by `torch.load("iter-22000-avg-5.pt")`. +You can later load it by `torch.load("iter-22000-avg-5.pt", weights_only=False)`. 
""" diff --git a/egs/multi_zh_en/ASR/zipformer/pretrained.py b/egs/multi_zh_en/ASR/zipformer/pretrained.py index 2fcde550b..0f8de5020 100755 --- a/egs/multi_zh_en/ASR/zipformer/pretrained.py +++ b/egs/multi_zh_en/ASR/zipformer/pretrained.py @@ -294,7 +294,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py index bed3856e4..dcc888de8 100755 --- a/egs/ptb/LM/local/sort_lm_training_data.py +++ b/egs/ptb/LM/local/sort_lm_training_data.py @@ -64,7 +64,7 @@ def main(): if out_lm_data.is_file(): logging.warning(f"{out_lm_data} exists - skipping") return - data = torch.load(in_lm_data) + data = torch.load(in_lm_data, weights_only=False) words2bpe = data["words"] sentences = data["sentences"] sentence_lengths = data["sentence_lengths"] diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py index 3790045fa..aedca9d5e 100755 --- a/egs/ptb/LM/local/test_prepare_lm_training_data.py +++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py @@ -37,7 +37,7 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(str(bpe_model)) - data = torch.load(lm_training_data) + data = torch.load(lm_training_data, weights_only=False) words2bpe = data["words"] sentences = data["sentences"] diff --git a/egs/reazonspeech/ASR/RESULTS.md b/egs/reazonspeech/ASR/RESULTS.md index c0b4fe54a..92610d75b 100644 --- a/egs/reazonspeech/ASR/RESULTS.md +++ b/egs/reazonspeech/ASR/RESULTS.md @@ -47,3 +47,41 @@ The decoding command is: --blank-penalty 0 ``` +#### Streaming + +We have not completed evaluation of our models yet and will add evaluation results here once it's completed. 
+ +The training command is: +```shell +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 40 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --causal 1 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --lang data/lang_char \ + --max-duration 1600 +``` + +The decoding command is: + +```shell +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp-large \ + --lang data/lang_char \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 +``` + diff --git a/egs/reazonspeech/ASR/local/utils/tokenizer.py b/egs/reazonspeech/ASR/local/utils/tokenizer.py index c9be72be1..ba71cff89 100644 --- a/egs/reazonspeech/ASR/local/utils/tokenizer.py +++ b/egs/reazonspeech/ASR/local/utils/tokenizer.py @@ -12,7 +12,6 @@ class Tokenizer: @staticmethod def add_arguments(parser: argparse.ArgumentParser): group = parser.add_argument_group(title="Lang related options") - group.add_argument("--lang", type=Path, help="Path to lang directory.") group.add_argument( diff --git a/egs/reazonspeech/ASR/zipformer/decode.py b/egs/reazonspeech/ASR/zipformer/decode.py index cdd2145f2..7b180bb02 100755 --- a/egs/reazonspeech/ASR/zipformer/decode.py +++ b/egs/reazonspeech/ASR/zipformer/decode.py @@ -1008,7 +1008,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/reazonspeech/ASR/zipformer/streaming_decode.py b/egs/reazonspeech/ASR/zipformer/streaming_decode.py index 4c18c7563..7e3199e09 100755 --- a/egs/reazonspeech/ASR/zipformer/streaming_decode.py +++ b/egs/reazonspeech/ASR/zipformer/streaming_decode.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) -# +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) # See ../../../../LICENSE for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,28 +18,23 @@ """ Usage: -./pruned_transducer_stateless7_streaming/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --decode-chunk-len 32 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --decoding_method greedy_search \ - --lang data/lang_char \ - --num-decode-streams 2000 +./zipformer/streaming_decode.py--epoch 28 --avg 15 --causal 1 --chunk-size 32 --left-context-frames 256 --exp-dir ./zipformer/exp-large --lang data/lang_char --num-encoder-layers 2,2,4,5,4,2 --feedforward-dim 512,768,1536,2048,1536,768 --encoder-dim 192,256,512,768,512,256 --encoder-unmasked-dim 192,192,256,320,256,192 + """ import argparse import logging import math +import os +import pdb +import subprocess as sp from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 import numpy as np import torch -import torch.nn as nn from asr_datamodule import ReazonSpeechAsrDataModule -from decode import save_results from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet @@ -48,9 +44,9 
@@ from streaming_beam_search import ( modified_beam_search, ) from tokenizer import Tokenizer +from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import stack_states, unstack_states +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -58,7 +54,14 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import AttributeDict, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) LOG_EPS = math.log(1e-10) @@ -73,7 +76,7 @@ def get_parser(): type=int, default=28, help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -87,12 +90,6 @@ def get_parser(): """, ) - parser.add_argument( - "--gpu", - type=int, - default=0, - ) - parser.add_argument( "--avg", type=int, @@ -116,7 +113,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="zipformer/exp", help="The experiment dir", ) @@ -127,6 +124,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -138,14 +142,6 @@ def get_parser(): """, ) - parser.add_argument( - "--decoding-graph", - type=str, - default="", - help="""Used only when --decoding-method is - fast_beam_search""", - ) - parser.add_argument( "--num_active_paths", type=int, @@ -157,7 +153,7 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4.0, + default=4, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. @@ -194,18 +190,235 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) - parser.add_argument( - "--res-dir", - type=Path, - default=None, - help="The path to save results.", - ) - add_model_arguments(parser) return parser +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. 
+ + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
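`stack_states` and `unstack_states` are mutual inverses: the per-utterance state list (six cached tensors per encoder layer, then the `encoder_embed` left padding, then `processed_lens`) is concatenated along its batch dimension and split back with `chunk`. The following is a toy round-trip sketch under assumed, deliberately tiny dimensions, only to make the layout concrete; the real shapes come from the Zipformer configuration:

```python
import torch

# One encoder layer, two utterances, tiny hypothetical dimensions.
left_ctx, key_dim, val_dim = 4, 8, 8
heads, head_dim, channels, pad, freqs = 2, 4, 3, 2, 5


def one_utt_state():
    return [
        torch.randn(left_ctx, 1, key_dim),          # cached_key
        torch.randn(heads, 1, left_ctx, head_dim),  # cached_nonlin_attn
        torch.randn(left_ctx, 1, val_dim),          # cached_val1
        torch.randn(left_ctx, 1, val_dim),          # cached_val2
        torch.randn(1, channels, pad),              # cached_conv1
        torch.randn(1, channels, pad),              # cached_conv2
        torch.randn(1, channels, pad, freqs),       # encoder_embed left pad
        torch.zeros(1, dtype=torch.int32),          # processed_lens
    ]


states = [one_utt_state(), one_utt_state()]
batched = stack_states(states)       # batch dimension becomes 2 everywhere
recovered = unstack_states(batched)  # inverse operation
assert all(
    torch.equal(a, b)
    for utt_a, utt_b in zip(states, recovered)
    for a, b in zip(utt_a, utt_b)
)
```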
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. + """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + def decode_one_chunk( params: AttributeDict, model: nn.Module, @@ -224,27 +437,32 @@ def decode_one_chunk( Returns: Return a List containing which DecodeStreams are finished. 
""" - device = model.device + # pdb.set_trace() + # print(model) + # print(model.device) + # device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) features = [] feature_lens = [] states = [] - processed_lens = [] + processed_lens = [] # Used in fast-beam-search for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + feat, feat_len = stream.get_feature_frames(chunk_size * 2) features.append(feat) feature_lens.append(feat_len) states.append(stream.states) processed_lens.append(stream.done_frames) - feature_lens = torch.tensor(feature_lens, device=device) + feature_lens = torch.tensor(feature_lens, device=model.device) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling - # factor in encoders is 8. - # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. - tail_length = 23 + # Make sure the length after encoder_embed is at least 1. + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -256,12 +474,14 @@ def decode_one_chunk( ) states = stack_states(states) - processed_lens = torch.tensor(processed_lens, device=device) - encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -269,6 +489,7 @@ def decode_one_chunk( if params.decoding_method == "greedy_search": greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=model.device) processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( model=model, @@ -295,8 +516,9 @@ def decode_one_chunk( for i in range(len(decode_streams)): decode_streams[i].states = states[i] decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) + # if decode_streams[i].done: + # finished_streams.append(i) + finished_streams.append(i) return finished_streams @@ -305,7 +527,7 @@ def decode_dataset( cuts: CutSet, params: AttributeDict, model: nn.Module, - sp: Tokenizer, + tokenizer: Tokenizer, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -317,7 +539,7 @@ def decode_dataset( It is returned by :func:`get_params`. model: The neural model. - sp: + tokenizer: The BPE model. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used @@ -338,14 +560,14 @@ def decode_dataset( opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 - log_interval = 50 + log_interval = 100 decode_results = [] # Contain decode streams currently running. decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
- initial_states = model.encoder.get_init_state(device=device) + initial_states = get_init_states(model=model, batch_size=1, device=device) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -361,15 +583,19 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) fbank = Fbank(opts) feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) - decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode] - + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text decode_streams.append(decode_stream) while len(decode_streams) >= params.num_decode_streams: @@ -380,8 +606,8 @@ def decode_dataset( decode_results.append( ( decode_streams[i].id, - sp.text2word(decode_streams[i].ground_truth), - sp.text2word(sp.decode(decode_streams[i].decoding_result())), + decode_streams[i].ground_truth.split(), + tokenizer.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] @@ -391,18 +617,37 @@ def decode_dataset( # decode final chunks of last sequences while len(decode_streams): + # print("INSIDE LEN DECODE STREAMS") + # pdb.set_trace() + # print(model.device) + # test_device = model.device + # print("done") finished_streams = decode_one_chunk( params=params, model=model, decode_streams=decode_streams ) + # print('INSIDE FOR LOOP ') + # print(finished_streams) + + if not finished_streams: + print("No finished streams, breaking the loop") + break + for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - sp.text2word(decode_streams[i].ground_truth), - sp.text2word(sp.decode(decode_streams[i].decoding_result())), + try: + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + tokenizer.decode(decode_streams[i].decoding_result()).split(), + ) ) - ) - del decode_streams[i] + del decode_streams[i] + except IndexError as e: + print(f"IndexError: {e}") + print(f"decode_streams length: {len(decode_streams)}") + print(f"finished_streams: {finished_streams}") + print(f"i: {i}") + continue if params.decoding_method == "greedy_search": key = "greedy_search" @@ -416,9 +661,54 @@ def decode_dataset( key = f"num_active_paths_{params.num_active_paths}" else: raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + torch.cuda.synchronize() return {key: decode_results} +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
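The relaxed amplitude check above still assumes waveforms normalized to roughly [-1, 1]; the factor-of-10 headroom only absorbs lhotse `AudioTransform` classes (as the comment in the hunk notes) that can push peaks slightly past 1.0, while raw 16-bit PCM in the ±32768 range would still trip the assertion. A minimal sketch of the normalization the check expects, with a hypothetical int16 buffer:

```python
import numpy as np

pcm16 = np.array([0, 12000, -32768], dtype=np.int16)  # hypothetical raw samples
audio = pcm16.astype(np.float32) / 32768.0            # normalize to ~[-1, 1]

assert np.abs(audio).max() <= 10, "Should be normalized to [-1, 1], 10 for tolerance..."
```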
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + @torch.no_grad() def main(): parser = get_parser() @@ -430,16 +720,20 @@ def main(): params = get_params() params.update(vars(args)) - if not params.res_dir: - params.res_dir = params.exp_dir / "streaming" / params.decoding_method + params.res_dir = params.exp_dir / "streaming" / params.decoding_method if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" # for fast_beam_search if params.decoding_method == "fast_beam_search": @@ -455,21 +749,21 @@ def main(): device = torch.device("cpu") if torch.cuda.is_available(): - device = torch.device("cuda", params.gpu) + device = torch.device("cuda", 0) logging.info(f"Device: {device}") - sp = Tokenizer.load(params.lang, params.lang_type) + sp_token = Tokenizer.load(params.lang, params.lang_type) - # and is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # and is defined in local/train_bpe_model.py + params.blank_id = sp_token.piece_to_id("") + params.unk_id = sp_token.piece_to_id("") + params.vocab_size = sp_token.get_piece_size() logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) if not params.use_averaged_model: if params.iter > 0: @@ -553,42 +847,51 @@ def main(): model.device = device decoding_graph = None - if params.decoding_graph: - decoding_graph = k2.Fsa.from_dict( - torch.load(params.decoding_graph, map_location=device) - ) - elif params.decoding_method == "fast_beam_search": + if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + # we need cut ids to display recognition results. 
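With the suffix handling in `main()` above, every result file name now records the streaming configuration. For the decoding command listed in RESULTS.md earlier (`--epoch 28 --avg 15 --chunk-size 32 --left-context-frames 256`), and taking greedy search as the example method, the files written by `save_results` would be named as sketched below; this only reconstructs the string building, with the directory being `<exp-dir>/streaming/<decoding-method>`:

```python
from pathlib import Path

epoch, avg, chunk, left_ctx = 28, 15, "32", "256"  # values from the decode command
method = "greedy_search"

suffix = f"epoch-{epoch}-avg-{avg}-chunk-{chunk}-left-context-{left_ctx}"
res_dir = Path("zipformer/exp-large") / "streaming" / method

print(res_dir / f"recogs-valid-{method}-{suffix}.txt")
print(res_dir / f"errs-valid-{method}-{suffix}.txt")
print(res_dir / f"wer-summary-valid-{method}-{suffix}.txt")
```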
args.return_cuts = True reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - for subdir in ["valid"]: + valid_cuts = reazonspeech_corpus.valid_cuts() + test_cuts = reazonspeech_corpus.test_cuts() + + test_sets = ["valid", "test"] + test_cuts = [valid_cuts, test_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): results_dict = decode_dataset( - cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(), + cuts=test_cut, params=params, model=model, - sp=sp, + tokenizer=sp_token, decoding_graph=decoding_graph, ) - tot_err = save_results( - params=params, test_set_name=subdir, results_dict=results_dict + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, ) - with ( - params.res_dir - / ( - f"{subdir}-{params.decode_chunk_len}" - f"_{params.avg}_{params.epoch}.cer" - ) - ).open("w") as fout: - if len(tot_err) == 1: - fout.write(f"{tot_err[0][1]}") - else: - fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) + # valid_cuts = reazonspeech_corpus.valid_cuts() + + # for valid_cut in valid_cuts: + # results_dict = decode_dataset( + # cuts=valid_cut, + # params=params, + # model=model, + # sp=sp, + # decoding_graph=decoding_graph, + # ) + # save_results( + # params=params, + # test_set_name="valid", + # results_dict=results_dict, + # ) logging.info("Done!") diff --git a/egs/speech_llm/ASR_LLM/RESULTS.md b/egs/speech_llm/ASR_LLM/RESULTS.md index 830c70397..42dce80c5 100644 --- a/egs/speech_llm/ASR_LLM/RESULTS.md +++ b/egs/speech_llm/ASR_LLM/RESULTS.md @@ -42,8 +42,8 @@ huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishel # For multi-hans fine-tuned whisper model # huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt -# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct -huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct +# huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct +huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct # First, we only train the projector and freeze other modules. torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ @@ -55,9 +55,10 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --deepspeed \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --use-flash-attn True \ - --use-lora False --unfreeze-llm False + --use-lora False \ + --unfreeze-llm False -# Then we jointly train the projector and LLM LoRA modules. +# Then, we jointly train the projector and LLM LoRA modules. 
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --max-duration 200 \ --exp-dir ./whisper_llm_zh/exp_test \ @@ -67,7 +68,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ --deepspeed \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --use-flash-attn True \ - --use-lora True --unfreeze-llm True + --use-lora True \ + --unfreeze-llm True \ --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt ``` @@ -77,11 +79,11 @@ mkdir -p models/whisper models/qwen models/checkpoint huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B # For aishell fine-tuned whisper model -huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt +huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt # For multi-hans fine-tuned whisper model -# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt +# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt -huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct +huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt @@ -94,5 +96,6 @@ python3 ./whisper_llm_zh/decode.py \ --epoch 999 --avg 1 \ --manifest-dir data/fbank \ --use-flash-attn True \ - --use-lora True --dataset aishell + --use-lora True \ + --dataset aishell ``` diff --git a/egs/speech_llm/ASR_LLM/prepare.sh b/egs/speech_llm/ASR_LLM/prepare.sh old mode 100644 new mode 100755 index 6f5ed5448..8ca3c1c36 --- a/egs/speech_llm/ASR_LLM/prepare.sh +++ b/egs/speech_llm/ASR_LLM/prepare.sh @@ -7,6 +7,9 @@ set -eou pipefail stage=0 stop_stage=0 + +. shared/parse_options.sh || exit 1 + # All files generated by this script are saved in "data". # You can safely remove "data" and rerun this script to regenerate it. 
mkdir -p data @@ -23,7 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # pip install huggingface_hub['cli'] # for aishell 1 - huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse + huggingface-cli download --repo-type dataset --local-dir data yuekai/aishell_whisper_fbank_lhotse fi @@ -31,9 +34,9 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface" # for multi-hans-zh - huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse - huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse - huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse + huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse + huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse + huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then @@ -41,6 +44,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then # for speechio test sets mkdir data_speechio - huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio + huggingface-cli download --repo-type model --local-dir data_speechio yuekai/icefall_asr_speechio mv data_speechio/fbank/* data/fbank fi diff --git a/egs/speech_llm/ASR_LLM/shared b/egs/speech_llm/ASR_LLM/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/speech_llm/ASR_LLM/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 882ce4fbf..7c3901c20 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN from transformers import AutoModelForCausalLM, AutoTokenizer from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint +from icefall.checkpoint import load_checkpoint from icefall.env import get_env_info from icefall.utils import ( AttributeDict, @@ -95,10 +95,10 @@ def average_checkpoints( """ n = len(filenames) - if "model" in torch.load(filenames[0], map_location=device): - avg = torch.load(filenames[0], map_location=device)["model"] + if "model" in torch.load(filenames[0], map_location=device, weights_only=False): + avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"] else: - avg = torch.load(filenames[0], map_location=device) + avg = torch.load(filenames[0], map_location=device, weights_only=False) # Identify shared parameters. 
Two parameters are said to be shared # if they have the same data_ptr @@ -113,10 +113,10 @@ def average_checkpoints( uniqued_names = list(uniqued.values()) for i in range(1, n): - if "model" in torch.load(filenames[i], map_location=device): - state_dict = torch.load(filenames[i], map_location=device)["model"] + if "model" in torch.load(filenames[i], map_location=device, weights_only=False): + state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] else: - state_dict = torch.load(filenames[i], map_location=device) + state_dict = torch.load(filenames[i], map_location=device, weights_only=False) for k in uniqued_names: avg[k] += state_dict[k] @@ -357,43 +357,6 @@ def decode_dataset( Returns: Return a dict, whose key may be "beam-search". """ - - def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: - """ - Text normalization similar to M2MeT challenge baseline. - See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl - """ - if normalize == "none": - return text - elif normalize == "m2met": - import re - - text = text.replace(" ", "") - text = text.replace("", "") - text = text.replace("<%>", "") - text = text.replace("<->", "") - text = text.replace("<$>", "") - text = text.replace("<#>", "") - text = text.replace("<_>", "") - text = text.replace("", "") - text = text.replace("`", "") - text = text.replace("&", "") - text = text.replace(",", "") - if re.search("[a-zA-Z]", text): - text = text.upper() - text = text.replace("A", "A") - text = text.replace("a", "A") - text = text.replace("b", "B") - text = text.replace("c", "C") - text = text.replace("k", "K") - text = text.replace("t", "T") - text = text.replace(",", "") - text = text.replace("丶", "") - text = text.replace("。", "") - text = text.replace("、", "") - text = text.replace("?", "") - return text - results = [] num_cuts = 0 @@ -406,6 +369,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( @@ -418,12 +382,8 @@ def decode_dataset( for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_text = normalize_text_alimeeting(ref_text) - ref_words = ref_text.split() - print(f"ref: {ref_text}") - print(f"hyp: {''.join(hyp_words)}") - this_batch.append((cut_id, ref_words, hyp_words)) + for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_text)) results[lm_scale].extend(this_batch) @@ -439,40 +399,38 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): - - enable_log = True test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") + store_transcripts(filename=recog_path, texts=results, char_level=True) + logging.info(f"The transcripts are stored in {recog_path}") - # The following prints out WERs, per-word error 
statistics and aligned + # The following prints out CERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) - # we compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, ) test_set_wers[key] = wer - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: @@ -495,9 +453,13 @@ def main(): params = get_params() params.update(vars(args)) + + params.res_dir = params.exp_dir / f"{params.method}" + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" setup_logger( - f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + params.res_dir + / f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}" ) logging.info("Decoding started") @@ -574,23 +536,20 @@ def main(): if params.avg > 1: start = params.epoch - params.avg + 1 assert start >= 1, start - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - assert "model" not in checkpoint # deepspeed converted checkpoint only contains model state_dict filenames = [ - f"{params.exp_dir}/epoch-{epoch}.pt" + f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin" for epoch in range(start, params.epoch + 1) ] avg_checkpoint = average_checkpoints(filenames) model.load_state_dict(avg_checkpoint, strict=False) - filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save(avg_checkpoint, filename) + # filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + # torch.save(avg_checkpoint, filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", weights_only=False, + map_location="cpu", ) model.load_state_dict(checkpoint, strict=False) @@ -643,8 +602,7 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 5f224c984..7162af958 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright 2023 Xiaomi Corp. 
(authors: Xiaoyu Yang) # 2024 Yuekai Zhang +# 2025 Yifan Yang # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -42,47 +43,32 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ """ import argparse -import copy import logging import os -import random import warnings from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple import deepspeed -import k2 import torch -import torch.multiprocessing as mp import torch.nn as nn import transformers import whisper from asr_datamodule import AsrDataModule from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict -from label_smoothing import LabelSmoothingLoss -from lhotse import CutSet, load_manifest from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector from multi_dataset import MultiDataset -from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +from peft import LoraConfig, get_peft_model from torch import Tensor from torch.utils.tensorboard import SummaryWriter from transformers import AutoModelForCausalLM, AutoTokenizer from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from icefall import diagnostics from icefall.dist import get_rank, get_world_size from icefall.env import get_env_info -from icefall.utils import ( - AttributeDict, - MetricsTracker, - filter_uneven_sized_batch, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool DEFAULT_SPEECH_TOKEN = "" @@ -286,13 +272,6 @@ def compute_loss( Returns: Return a tuple of two elements. The first element is the loss tensor. """ - # For the uneven-sized batch, the total duration after padding would possibly - # cause OOM. Hence, for each batch, which is sorted descendingly by length, - # we simply drop the last few shortest samples, so that the retained total frames - # (after padding) would not exceed `allowed_max_frames`: - # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - # where `max_frames = max_duration * 1000 // frame_shift_ms`. - # We set allowed_excess_duration_ratio=0.1. def preprocess( messages, @@ -347,46 +326,6 @@ def compute_loss( return input_ids, attention_mask, target_ids - def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: - """ - Text normalization similar to M2MeT challenge baseline. 
- See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl - """ - if normalize == "none": - return text - elif normalize == "m2met": - import re - - text = text.replace(" ", "") - text = text.replace("", "") - text = text.replace("<%>", "") - text = text.replace("<->", "") - text = text.replace("<$>", "") - text = text.replace("<#>", "") - text = text.replace("<_>", "") - text = text.replace("", "") - text = text.replace("`", "") - text = text.replace("&", "") - text = text.replace(",", "") - if re.search("[a-zA-Z]", text): - text = text.upper() - text = text.replace("A", "A") - text = text.replace("a", "A") - text = text.replace("b", "B") - text = text.replace("c", "C") - text = text.replace("k", "K") - text = text.replace("t", "T") - text = text.replace(",", "") - text = text.replace("丶", "") - text = text.replace("。", "") - text = text.replace("、", "") - text = text.replace("?", "") - return text - - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) - device = next(model.parameters()).device feature = batch["inputs"] @@ -397,11 +336,10 @@ def compute_loss( batch_idx_train = params.batch_idx_train supervisions = batch["supervisions"] texts = batch["supervisions"]["text"] - # remove spaces in texts - texts = [normalize_text_alimeeting(text) for text in texts] messages = [] for i, text in enumerate(texts): + text = text.replace(" ", "") message = [ {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, {"role": "assistant", "content": text}, @@ -516,14 +454,17 @@ def train_one_epoch( The rank of the node in DDP training. If no DDP is used, it should be set to 0. 
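In the slimmed-down `compute_loss`, the AliMeeting-specific normalization is gone; each transcript is only stripped of spaces and wrapped into a two-turn chat message, with `DEFAULT_SPEECH_TOKEN` standing in for the audio in the user turn (its literal value is defined at the top of train.py but appears stripped in this diff rendering). A minimal sketch of that message construction for a toy batch; `preprocess` then turns these messages into `input_ids`, `attention_mask`, and `target_ids`:

```python
DEFAULT_SPEECH_TOKEN = "<speech>"  # placeholder; use the value defined in train.py

texts = ["今天 天气 很好", "我们 去 公园"]  # hypothetical supervision texts

messages = []
for text in texts:
    text = text.replace(" ", "")  # drop the word separators from the manifests
    messages.append(
        [
            {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
            {"role": "assistant", "content": text},
        ]
    )
```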
""" - model.encoder_projector.train() + model.train() + model.encoder.eval() + if not params.unfreeze_llm: + model.llm.eval() tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -533,6 +474,9 @@ def train_one_epoch( world_size=world_size, ) model.train() + model.encoder.eval() + if not params.unfreeze_llm: + model.llm.eval() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" @@ -648,7 +592,6 @@ def run(rank, world_size, args): speech_encoder_dim = whisper_model.dims.n_audio_state for name, param in speech_encoder.named_parameters(): param.requires_grad = False - speech_encoder.eval() tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) if params.use_flash_attn: @@ -671,7 +614,6 @@ def run(rank, world_size, args): if not params.unfreeze_llm: for name, param in llm.named_parameters(): param.requires_grad = False - llm.eval() else: if params.use_lora: lora_config = LoraConfig( @@ -710,7 +652,7 @@ def run(rank, world_size, args): ) if params.pretrained_model_path: - checkpoint = torch.load(params.pretrained_model_path, map_location="cpu") + checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False) missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) num_param = sum([p.numel() for p in model.parameters()]) @@ -728,7 +670,7 @@ def run(rank, world_size, args): logging.info(f"Device: {device}") model.to(device) - assert params.deepspeed and world_size > 1 + assert params.deepspeed logging.info("Using DeepSpeed") model, optimizer, _, scheduler = deepspeed.initialize( args=params, model=model, model_parameters=model.parameters() @@ -762,9 +704,9 @@ def run(rank, world_size, args): sampler_state_dict = None if params.sampler_state_dict_path: - sampler_state_dict = torch.load(params.sampler_state_dict_path) + sampler_state_dict = torch.load(params.sampler_state_dict_path, weights_only=False) sampler_state_dict["max_duration"] = params.max_duration - # TODO: load sampler state dict + train_dl = data_module.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) @@ -806,15 +748,15 @@ def run(rank, world_size, args): model.save_checkpoint( save_dir=params.exp_dir, - tag=f"epoch-{params.cur_epoch}", + tag=f"zero-epoch-{params.cur_epoch}", client_state={}, exclude_frozen_parameters=True, ) if rank == 0: convert_zero_checkpoint_to_fp32_state_dict( params.exp_dir, - f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", - tag=f"epoch-{params.cur_epoch}", + f"{params.exp_dir}/epoch-{params.cur_epoch}", + tag=f"zero-epoch-{params.cur_epoch}", exclude_frozen_parameters=True, ) # save sampler state dict into checkpoint @@ -824,7 +766,7 @@ def run(rank, world_size, args): f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt", ) - os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") + os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}") logging.info("Done!") @@ -865,6 +807,7 @@ def main(): torch.set_num_threads(1) torch.set_num_interop_threads(1) + warnings.filterwarnings("ignore", category=FutureWarning) run(rank=rank, world_size=world_size, args=args) diff --git 
a/egs/speechio/ASR/whisper/decode.py b/egs/speechio/ASR/whisper/decode.py index c20f1f714..9ee3ecd04 100644 --- a/egs/speechio/ASR/whisper/decode.py +++ b/egs/speechio/ASR/whisper/decode.py @@ -91,10 +91,10 @@ def average_checkpoints( """ n = len(filenames) - if "model" in torch.load(filenames[0], map_location=device): - avg = torch.load(filenames[0], map_location=device)["model"] + if "model" in torch.load(filenames[0], map_location=device, weights_only=False): + avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"] else: - avg = torch.load(filenames[0], map_location=device) + avg = torch.load(filenames[0], map_location=device, weights_only=False) # Identify shared parameters. Two parameters are said to be shared # if they have the same data_ptr @@ -109,10 +109,10 @@ def average_checkpoints( uniqued_names = list(uniqued.values()) for i in range(1, n): - if "model" in torch.load(filenames[i], map_location=device): - state_dict = torch.load(filenames[i], map_location=device)["model"] + if "model" in torch.load(filenames[i], map_location=device, weights_only=False): + state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] else: - state_dict = torch.load(filenames[i], map_location=device) + state_dict = torch.load(filenames[i], map_location=device, weights_only=False) for k in uniqued_names: avg[k] += state_dict[k] @@ -447,7 +447,7 @@ def main(): start = params.epoch - params.avg assert start >= 1, start checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False ) if "model" not in checkpoint: # deepspeed converted checkpoint only contains model state_dict @@ -476,7 +476,7 @@ def main(): torch.save(model.state_dict(), filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False ) if "model" not in checkpoint: model.load_state_dict(checkpoint, strict=True) diff --git a/egs/speechio/ASR/zipformer/decode.py b/egs/speechio/ASR/zipformer/decode.py index ffdd7b500..62a7e8943 100644 --- a/egs/speechio/ASR/zipformer/decode.py +++ b/egs/speechio/ASR/zipformer/decode.py @@ -784,7 +784,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/spgispeech/ASR/zipformer/decode.py b/egs/spgispeech/ASR/zipformer/decode.py index 90d318919..7cc23d1f0 100755 --- a/egs/spgispeech/ASR/zipformer/decode.py +++ b/egs/spgispeech/ASR/zipformer/decode.py @@ -988,7 +988,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/spgispeech/ASR/zipformer/pretrained.py b/egs/spgispeech/ASR/zipformer/pretrained.py index a562fb9f6..a2f8e5544 100755 --- a/egs/spgispeech/ASR/zipformer/pretrained.py +++ b/egs/spgispeech/ASR/zipformer/pretrained.py @@ -291,7 +291,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = 
torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py index 52e501ae1..9e28043ab 100755 --- a/egs/swbd/ASR/conformer_ctc/decode.py +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -698,7 +698,7 @@ def main(): H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) ) assert HLG.requires_grad is False @@ -738,7 +738,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device, weights_only=False) G = k2.Fsa.from_dict(d) if params.method in [ diff --git a/egs/swbd/ASR/local/sort_lm_training_data.py b/egs/swbd/ASR/local/sort_lm_training_data.py index bed3856e4..dcc888de8 100755 --- a/egs/swbd/ASR/local/sort_lm_training_data.py +++ b/egs/swbd/ASR/local/sort_lm_training_data.py @@ -64,7 +64,7 @@ def main(): if out_lm_data.is_file(): logging.warning(f"{out_lm_data} exists - skipping") return - data = torch.load(in_lm_data) + data = torch.load(in_lm_data, weights_only=False) words2bpe = data["words"] sentences = data["sentences"] sentence_lengths = data["sentence_lengths"] diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py index c8cf9b881..aa23c4cb3 100755 --- a/egs/tal_csasr/ASR/local/prepare_lang.py +++ b/egs/tal_csasr/ASR/local/prepare_lang.py @@ -28,7 +28,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. 
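The `weights_only=False` additions that run through this diff all address the same issue: icefall checkpoints and k2 graph dumps are pickled dictionaries holding more than bare tensors (model state dicts plus sampler/optimizer metadata, FSA attribute dicts), and recent PyTorch releases (2.6 and later) changed the default of `torch.load` to `weights_only=True`, which refuses such files. Passing `weights_only=False` restores full unpickling, so it should only be used on checkpoints from trusted sources. A minimal sketch of the loading pattern, with a hypothetical checkpoint path:

```python
import torch

# Hypothetical checkpoint written by icefall training; it is a dict whose
# "model" key holds the state_dict, alongside other training metadata.
ckpt = torch.load("exp/epoch-28.pt", map_location="cpu", weights_only=False)
state_dict = ckpt["model"] if "model" in ckpt else ckpt
```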
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py index 8a74ee745..098ea3f4c 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py @@ -235,7 +235,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/decode.py index 885778965..f4361b528 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/decode.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/decode.py @@ -766,7 +766,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py index 6e07b5949..21d80bfef 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py @@ -248,7 +248,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/tedlium3/ASR/conformer_ctc2/decode.py b/egs/tedlium3/ASR/conformer_ctc2/decode.py index 28d39de70..220c7a6c1 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/decode.py +++ b/egs/tedlium3/ASR/conformer_ctc2/decode.py @@ -675,7 +675,7 @@ def main() -> None: H = None bpe_model = None HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device, weights_only=False) ) assert HLG.requires_grad is False @@ -687,7 +687,7 @@ def main() -> None: if params.lm_path.is_file() and params.lm_path.suffix == ".pt": logging.info(f"Loading pre-compiled {params.lm_path.name}") - d = torch.load(params.lm_path, map_location=device) + d = torch.load(params.lm_path, map_location=device, weights_only=False) G = k2.Fsa.from_dict(d) elif not params.lm_path.is_file() and params.lm_path.suffix == ".txt": raise FileNotFoundError(f"No such language model file: '{params.lm_path}'") diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py index 9e58fed00..f0a32a993 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py @@ -238,7 +238,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) 
model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py index 170f37767..2455f3630 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py @@ -422,7 +422,7 @@ def compute_loss( texts = batch["supervisions"]["text"] unk_id = params.unk_id - y = convert_texts_into_ids(texts, unk_id, sp=sp) + y = convert_texts_into_ids(texts, sp=sp) y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py index 5300fe764..73e18e20d 100644 --- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py @@ -257,7 +257,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py index 6fed32e81..c6fa34e70 100755 --- a/egs/tedlium3/ASR/transducer_stateless/train.py +++ b/egs/tedlium3/ASR/transducer_stateless/train.py @@ -397,7 +397,7 @@ def compute_loss( texts = batch["supervisions"]["text"] unk_id = params.unk_id - y = convert_texts_into_ids(texts, unk_id, sp=sp) + y = convert_texts_into_ids(texts, sp=sp) y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): diff --git a/egs/tedlium3/ASR/zipformer/decode.py b/egs/tedlium3/ASR/zipformer/decode.py index 2c4123c20..7f8a7ef3e 100755 --- a/egs/tedlium3/ASR/zipformer/decode.py +++ b/egs/tedlium3/ASR/zipformer/decode.py @@ -784,7 +784,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py index c8562f4fb..150fcbc60 100755 --- a/egs/timit/ASR/local/compile_hlg.py +++ b/egs/timit/ASR/local/compile_hlg.py @@ -63,11 +63,11 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) if Path("data/lm/G.pt").is_file(): logging.info("Loading pre-compiled G") - d = torch.load("data/lm/G.pt") + d = torch.load("data/lm/G.pt", weights_only=False) G = k2.Fsa.from_dict(d) else: logging.info("Loading G_3_gram.fst.txt") diff --git a/egs/timit/ASR/local/prepare_lang.py b/egs/timit/ASR/local/prepare_lang.py index e9f283274..d5087ca67 100755 --- a/egs/timit/ASR/local/prepare_lang.py +++ b/egs/timit/ASR/local/prepare_lang.py @@ -29,7 +29,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. 
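The prepare_lang.py docstrings above show the same pattern for k2 graphs: the `.pt` files written during lang preparation are dictionaries of FSA attributes, so they too need `weights_only=False`, after which `k2.Fsa.from_dict` rebuilds the graph. A short sketch mirroring the decoding scripts in this diff (the paths and device choice are illustrative, not recipe defaults):

```python
import k2
import torch

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

d = torch.load("data/lang_phone/L.pt", map_location="cpu", weights_only=False)
lexicon = k2.Fsa.from_dict(d).to(device)

HLG = k2.Fsa.from_dict(
    torch.load("data/lang_phone/HLG.pt", map_location=device, weights_only=False)
)
assert HLG.requires_grad is False  # decoding graphs carry no gradients
```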
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py index 4beeed18c..541ff09a0 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py @@ -398,7 +398,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -424,7 +424,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False) G = k2.Fsa.from_dict(d).to(device) if params.method == "whole-lattice-rescoring": diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py index 0d77bc512..78b17558c 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py @@ -167,13 +167,13 @@ def main(): subsampling_factor=params.subsampling_factor, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"]) model.to(device) model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False)) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -181,7 +181,7 @@ def main(): if params.method == "whole-lattice-rescoring": logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu", weights_only=False)) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py index 502a48def..f3eebcc61 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py @@ -397,7 +397,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False)) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -423,7 +423,7 @@ def main(): torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") else: logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu", weights_only=False) G = k2.Fsa.from_dict(d).to(device) if params.method == "whole-lattice-rescoring": diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py index f06c8c211..a1e93b329 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py @@ -167,13 +167,13 @@ def main(): subsampling_factor=params.subsampling_factor, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"]) model.to(device) 
model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu", weights_only=False)) HLG = HLG.to(device) if not hasattr(HLG, "lm_scores"): # For whole-lattice-rescoring and attention-decoder @@ -181,7 +181,7 @@ def main(): if params.method == "whole-lattice-rescoring": logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu", weights_only=False)) # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = G.to(device) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 8b35187b1..2cc06df71 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -106,7 +106,7 @@ class WenetSpeechAsrDataModule: group.add_argument( "--num-buckets", type=int, - default=30, + default=15, help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) @@ -292,8 +292,7 @@ class WenetSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, + buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index 2bafe25d6..65afad8f0 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -640,7 +640,7 @@ def main(): lg_filename = params.lang_dir + "/LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index c34f1593d..d03b5485c 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -477,7 +477,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py index 642de72d7..51c4c13c0 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -220,7 +220,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py 
b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py index 17428e19d..f35042c07 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -220,7 +220,7 @@ def main(): logging.info("Creating model") model = get_transducer_model(params) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/wenetspeech/ASR/whisper/decode.py b/egs/wenetspeech/ASR/whisper/decode.py index 34b1c80ef..2363a6992 100755 --- a/egs/wenetspeech/ASR/whisper/decode.py +++ b/egs/wenetspeech/ASR/whisper/decode.py @@ -88,10 +88,10 @@ def average_checkpoints( """ n = len(filenames) - if "model" in torch.load(filenames[0], map_location=device): - avg = torch.load(filenames[0], map_location=device)["model"] + if "model" in torch.load(filenames[0], map_location=device, weights_only=False): + avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"] else: - avg = torch.load(filenames[0], map_location=device) + avg = torch.load(filenames[0], map_location=device, weights_only=False) # Identify shared parameters. Two parameters are said to be shared # if they have the same data_ptr @@ -106,10 +106,10 @@ def average_checkpoints( uniqued_names = list(uniqued.values()) for i in range(1, n): - if "model" in torch.load(filenames[i], map_location=device): - state_dict = torch.load(filenames[i], map_location=device)["model"] + if "model" in torch.load(filenames[i], map_location=device, weights_only=False): + state_dict = torch.load(filenames[i], map_location=device, weights_only=False)["model"] else: - state_dict = torch.load(filenames[i], map_location=device) + state_dict = torch.load(filenames[i], map_location=device, weights_only=False) for k in uniqued_names: avg[k] += state_dict[k] @@ -435,7 +435,7 @@ def main(): start = params.epoch - params.avg assert start >= 1, start checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False ) if "model" not in checkpoint: # deepspeed converted checkpoint only contains model state_dict @@ -464,7 +464,7 @@ def main(): torch.save(model.state_dict(), filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu", weights_only=False ) if "model" not in checkpoint: model.load_state_dict(checkpoint, strict=True) diff --git a/egs/wenetspeech/ASR/zipformer/decode.py b/egs/wenetspeech/ASR/zipformer/decode.py index 0fbc8244b..63d29b7fd 100755 --- a/egs/wenetspeech/ASR/zipformer/decode.py +++ b/egs/wenetspeech/ASR/zipformer/decode.py @@ -757,7 +757,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/wenetspeech/ASR/zipformer/export_rknn_transducer_streaming.py b/egs/wenetspeech/ASR/zipformer/export_rknn_transducer_streaming.py new file mode 120000 index 000000000..8be19ef3d --- /dev/null +++ b/egs/wenetspeech/ASR/zipformer/export_rknn_transducer_streaming.py @@ -0,0 +1 @@ 
+../../../librispeech/ASR/zipformer/export_rknn_transducer_streaming.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py b/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py new file mode 120000 index 000000000..8c203406b --- /dev/null +++ b/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py b/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py new file mode 120000 index 000000000..6417f470f --- /dev/null +++ b/egs/wenetspeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/test_rknn_on_cpu_simulator_ctc_streaming.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index 8472b8531..0af7c1595 100755 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -108,7 +108,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 2: Finetune the model" + log "Stage 3: Finetune the model" # The following configuration of lr schedule should work well # You may also tune the following parameters to adjust learning rate schedule @@ -143,7 +143,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 1: Decode the finetuned model." + log "Stage 4: Decode the finetuned model." export CUDA_VISIBLE_DEVICES="0" for t in small large; do python ./zipformer/decode.py \ @@ -170,7 +170,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 2: Export the finetuned model." + log "Stage 5: Export the finetuned model." 
+ log "Stage 5: Export the finetuned model."
python ./zipformer/export.py \ --epoch 10 \ diff --git a/egs/wenetspeech/KWS/zipformer/decode-asr.py b/egs/wenetspeech/KWS/zipformer/decode-asr.py index 6425030eb..34014facc 100755 --- a/egs/wenetspeech/KWS/zipformer/decode-asr.py +++ b/egs/wenetspeech/KWS/zipformer/decode-asr.py @@ -706,7 +706,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py index 340a41231..a628c7e58 100755 --- a/egs/wenetspeech/KWS/zipformer/decode.py +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -35,7 +35,6 @@ from lhotse.cut import Cut from train import add_model_arguments, get_model, get_params from icefall import ContextGraph -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index d19172b38..72e786864 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -69,6 +69,7 @@ import argparse import copy import logging import warnings +from functools import partial from pathlib import Path from typing import List, Optional, Tuple, Union @@ -90,6 +91,7 @@ from train import ( add_training_arguments, compute_validation_loss, display_and_save_batch, + encode_text, get_adjusted_batch_count, get_model, get_params, @@ -100,7 +102,6 @@ from train import ( ) from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -110,11 +111,11 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, + num_tokens, setup_logger, str2bool, text_to_pinyin, @@ -215,7 +216,7 @@ def load_model_params( """ logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") + checkpoint = torch.load(ckpt, map_location="cpu", weights_only=False) # if module list is empty, load the whole model from ckpt if not init_modules: @@ -254,7 +255,6 @@ def load_model_params( def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: @@ -289,7 +289,7 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts, sep="/") + y = [c.supervisions[0].tokens for c in supervisions["cut"]] y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): @@ -347,7 +347,6 @@ def train_one_epoch( model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, @@ -418,7 +417,6 @@ def train_one_epoch( loss, loss_info = 
compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=True, ) @@ -436,7 +434,7 @@ def train_one_epoch( optimizer.zero_grad() except: # noqa save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params) raise if params.print_diagnostics and batch_idx == 5: @@ -523,7 +521,6 @@ def train_one_epoch( valid_info = compute_validation_loss( params=params, model=model, - graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, ) @@ -576,14 +573,10 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) + token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt") - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 if not params.use_transducer: params.ctc_loss_scale = 1.0 @@ -601,6 +594,9 @@ def run(rank, world_size, args): if params.continue_finetune: assert params.start_epoch > 0, params.start_epoch + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) checkpoints = load_checkpoint_if_available( params=params, model=model, model_avg=model_avg ) @@ -666,17 +662,10 @@ def run(rank, world_size, args): else: train_cuts = wenetspeech.nihaowenwen_train_cuts() - def encode_text(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = "/".join( - text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) - ) - c.supervisions[0].text = text - return c + _encode_text = partial(encode_text, token_table=token_table, params=params) train_cuts = train_cuts.filter(remove_short_utt) - train_cuts = train_cuts.map(encode_text) + train_cuts = train_cuts.map(_encode_text) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -691,7 +680,7 @@ def run(rank, world_size, args): valid_cuts = wenetspeech.nihaowenwen_dev_cuts() valid_cuts = valid_cuts.filter(remove_short_utt) - valid_cuts = valid_cuts.map(encode_text) + valid_cuts = valid_cuts.map(_encode_text) valid_dl = wenetspeech.valid_dataloaders(valid_cuts) if not params.print_diagnostics and params.scan_for_oom_batches: @@ -699,7 +688,6 @@ def run(rank, world_size, args): model=model, train_dl=train_dl, optimizer=optimizer, - graph_compiler=graph_compiler, params=params, ) @@ -724,7 +712,6 @@ def run(rank, world_size, args): model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - graph_compiler=graph_compiler, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, @@ -760,6 +747,8 @@ def main(): WenetSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.return_cuts = True world_size = args.world_size assert world_size >= 1 diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index 40960c2ae..5d9d8de36 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -53,6 +53,7 @@ import argparse import copy import logging import warnings +from functools import partial from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, 
Union @@ -79,7 +80,6 @@ from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -90,11 +90,11 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, + num_tokens, setup_logger, str2bool, text_to_pinyin, @@ -776,7 +776,6 @@ def save_checkpoint( def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: @@ -811,7 +810,7 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts, sep="/") + y = [c.supervisions[0].tokens for c in supervisions["cut"]] y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): @@ -859,7 +858,6 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, ) -> MetricsTracker: @@ -872,7 +870,6 @@ def compute_validation_loss( loss, loss_info = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=False, ) @@ -895,7 +892,6 @@ def train_one_epoch( model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, @@ -971,7 +967,6 @@ def train_one_epoch( loss, loss_info = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=True, ) @@ -988,7 +983,7 @@ def train_one_epoch( optimizer.zero_grad() except: # noqa save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params) raise if params.print_diagnostics and batch_idx == 5: @@ -1077,7 +1072,6 @@ def train_one_epoch( valid_info = compute_validation_loss( params=params, model=model, - graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, ) @@ -1098,6 +1092,20 @@ def train_one_epoch( params.best_train_loss = params.train_loss +def encode_text(c: Cut, token_table: k2.SymbolTable, params: AttributeDict): + text = c.supervisions[0].text + tokens = text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) + ids = [] + for t in tokens: + if t in token_table: + ids.append(token_table[t]) + else: + logging.warning(f"Text : {text} has OOV token : {t} , encode to ") + ids.append(token_table[""]) + c.supervisions[0].tokens = ids + return c + + def run(rank, world_size, args): """ Args: @@ -1130,14 +1138,10 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) + token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt") - params.blank_id = 
lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 if not params.use_transducer: params.ctc_loss_scale = 1.0 @@ -1216,17 +1220,10 @@ def run(rank, world_size, args): return True - def encode_text(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = "/".join( - text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) - ) - c.supervisions[0].text = text - return c + _encode_text = partial(encode_text, token_table=token_table, params=params) train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.map(encode_text) + train_cuts = train_cuts.map(_encode_text) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1240,7 +1237,7 @@ def run(rank, world_size, args): ) valid_cuts = wenetspeech.valid_cuts() - valid_cuts = valid_cuts.map(encode_text) + valid_cuts = valid_cuts.map(_encode_text) valid_dl = wenetspeech.valid_dataloaders(valid_cuts) if not params.print_diagnostics and params.scan_for_oom_batches: @@ -1248,7 +1245,6 @@ def run(rank, world_size, args): model=model, train_dl=train_dl, optimizer=optimizer, - graph_compiler=graph_compiler, params=params, ) @@ -1273,7 +1269,6 @@ def run(rank, world_size, args): model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - graph_compiler=graph_compiler, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, @@ -1307,7 +1302,6 @@ def run(rank, world_size, args): def display_and_save_batch( batch: dict, params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, ) -> None: """Display the batch statistics and save the batch into disk. @@ -1317,8 +1311,6 @@ def display_and_save_batch( for the content in it. params: Parameters for training. See :func:`get_params`. - graph_compiler: - The compiler to encode texts to ids. """ from lhotse.utils import uuid4 @@ -1332,8 +1324,8 @@ def display_and_save_batch( logging.info(f"features shape: {features.shape}") texts = supervisions["text"] - y = graph_compiler.texts_to_ids(texts) - num_tokens = sum(len(i) for i in y) + tokens = [c.supervisions[0].tokens for c in supervisions["cut"]] + num_tokens = sum(len(i) for i in tokens) logging.info(f"num tokens: {num_tokens}") @@ -1341,7 +1333,6 @@ def scan_pessimistic_batches_for_oom( model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, params: AttributeDict, ): from lhotse.dataset import find_pessimistic_batches @@ -1357,7 +1348,6 @@ def scan_pessimistic_batches_for_oom( loss, _ = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=True, ) @@ -1372,7 +1362,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." 
) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params) raise logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" @@ -1385,6 +1375,7 @@ def main(): args = parser.parse_args() args.lang_dir = Path(args.lang_dir) args.exp_dir = Path(args.exp_dir) + args.return_cuts = True world_size = args.world_size assert world_size >= 1 diff --git a/egs/wenetspeech4tts/TTS/README.md b/egs/wenetspeech4tts/TTS/README.md new file mode 100644 index 000000000..8329ae948 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/README.md @@ -0,0 +1,188 @@ +# Results +| Model | Seed-TTS test_zh CER | Comment | +|---------------------------------------|---------------------|--------| +| [vall-e](./valle) | 4.33% | ~150M | +| [f5-tts](./f5-tts) | 3.02% (16 steps) / 2.42% (32 steps) | F5-TTS-Small Config, ~155M | +| [f5-tts-semantic-token](./f5-tts) | 1.79% (16 steps) | Using pretrained cosyvoice2 semantic tokens as inputs rather than text tokens, ~155M | + +# Introduction + +[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset. + +> [!CAUTION] +> The next-gen Kaldi framework provides tools and models for generating high-quality synthetic speech (Text-to-Speech, TTS). +> While these recipes have the potential to advance various fields such as accessibility, language education, and AI-driven solutions, they also carry certain ethical and legal responsibilities. +> +> By using this framework, you agree to the following: +> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. +> +> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. +> +> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. +> +> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. + + +# [VALL-E](https://arxiv.org/abs/2301.02111) + +./valle contains the code for training the VALL-E TTS model. + +Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_wenetspeech4tts). The demo of the model trained with Wenetspeech4TTS Premium (945 hours) is available [here](https://huggingface.co/spaces/yuekai/valle_wenetspeech4tts_demo).
+ +Preparation: + +``` +bash prepare.sh +``` + +The training command is given below: + +``` +world_size=8 +exp_dir=exp/valle + +## Train AR model +python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +## Train NAR model +# cd ${exp_dir} +# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 +# cd - +python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ + --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ + --exp-dir ${exp_dir} --world-size ${world_size} +``` + +To inference, use: +``` +huggingface-cli login +huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_wenetspeech4tts +top_p=1.0 +python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ + --top-k -1 --temperature 1.0 \ + --text ./aishell3.txt \ + --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ + --text-extractor pypinyin_initials_finals --top-p ${top_p} +``` + +# [F5-TTS](https://arxiv.org/abs/2410.06885) + +./f5-tts contains the code for training the F5-TTS model. + +Generated samples and training logs of the wenetspeech basic 7k hours data can be found [here](https://huggingface.co/yuekai/f5-tts-small-wenetspeech4tts-basic/tensorboard). + +Preparation: + +``` +bash prepare.sh --stage 5 --stop_stage 6 +``` +(Note: To be compatible with the official F5-TTS checkpoint, we directly use `vocab.txt` from [here](https://github.com/SWivid/F5-TTS/blob/129014c5b43f135b0100d49a0c6804dd4cf673e1/data/Emilia_ZH_EN_pinyin/vocab.txt). To generate your own `vocab.txt`, you may refer to [the script](https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/train/datasets/prepare_emilia.py).)
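The shared `vocab.txt` holds one token per line; below is a rough sketch, assuming that format and the behaviour of `convert_char_to_pinyin` (the helper `f5-tts/infer.py` imports from `model.utils`), of how text could be mapped to token ids. The recipe's real tokenizer lives in `f5-tts/train.py` (`get_tokenizer`) and may differ:

```
# Run from egs/wenetspeech4tts/TTS/f5-tts so that model.utils is importable.
from model.utils import convert_char_to_pinyin  # same helper used by f5-tts/infer.py

# Build a token -> id map from the shared vocab file (format assumed: one token per line).
with open("vocab.txt", encoding="utf-8") as f:
    token2id = {tok: i for i, tok in enumerate(line.rstrip("\n") for line in f)}

# convert_char_to_pinyin takes a list of strings; treating its output as a list
# of token sequences is an assumption here.
text_list = convert_char_to_pinyin(["欢迎使用下一代 Kaldi"], polyphone=True)
ids = [[token2id[t] for t in seq if t in token2id] for seq in text_list]
print(len(ids[0]))
```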
+ +The training command is given below: + +``` +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece + +world_size=8 +exp_dir=exp/f5-tts-small +python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ + --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \ + --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ + --exp-dir ${exp_dir} --world-size ${world_size} +``` + +To inference with Icefall Wenetspeech4TTS trained F5-Small, use: +``` +huggingface-cli login +huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset +huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-small-wenetspeech4tts-basic +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x + +manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst +model_path=f5-tts-small-wenetspeech4tts-basic/epoch-56-avg-14.pt +# skip +python3 f5-tts/generate_averaged_model.py \ + --epoch 56 \ + --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --exp-dir exp/f5_small + + +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 +bash local/compute_wer.sh $output_dir $manifest +``` + +To inference with official Emilia trained F5-Base, use: +``` +huggingface-cli login +huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset +huggingface-cli download --local-dir F5-TTS SWivid/F5-TTS +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x + +manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst +model_path=./F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt + +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir +bash local/compute_wer.sh $output_dir $manifest +``` + +# F5-TTS-Semantic-Token + +./f5-tts contains the code for training F5-TTS-Semantic-Token. We replaced the text tokens in F5-TTS with pretrained cosyvoice2 semantic tokens. During inference, we use the pretrained CosyVoice2 LLM to predict the semantic tokens for target audios. We observed that this approach leads to faster convergence and improved prosody modeling results. + +Generated samples and training logs of wenetspeech basic 7k hours data can be found [here](https://huggingface.co/yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic/tree/main). 
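Because the CosyVoice2 semantic tokens are produced at a lower frame rate than the 24 kHz / 256-sample-hop mel features (93.75 frames per second), `f5-tts/infer.py` either interpolates the token sequence (`--interpolate-token True`) or scales the target length by 15/4 (`total_mel_len = int(total_mel_len / 4 * 15)`). The following is a hypothetical sketch of that 4-to-15 stretching; the actual `interpolate_tokens` in `f5-tts/train.py` may use a different scheme:

```
def repeat_tokens_to_mel_rate(tokens):
    """Stretch a token sequence by 15/4 by repeating each token 3 or 4 times.

    Illustration only; not the recipe's own interpolate_tokens().
    """
    out = []
    for i, t in enumerate(tokens):
        # 3 + 4 + 4 + 4 = 15 mel frames for every 4 semantic tokens
        out.extend([t] * (3 if i % 4 == 0 else 4))
    return out


assert len(repeat_tokens_to_mel_rate(list(range(8)))) == int(8 / 4 * 15)  # 30
```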
+ +Preparation: + +``` +# extract cosyvoice2 semantic tokens +bash prepare.sh --stage 5 --stop_stage 7 +``` + +The training command is given below: + +``` +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece + +world_size=8 +exp_dir=exp/f5-tts-semantic-token-small +python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ + --base-lr 1e-4 --warmup-steps 20000 --average-period 0 \ + --num-epochs 10 --start-epoch 1 --start-batch 0 \ + --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ + --exp-dir ${exp_dir} --world-size ${world_size} \ + --decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True +``` + +To inference with Icefall Wenetspeech4TTS trained F5-Small-Semantic-Token, use: +``` +huggingface-cli login +huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x + +split=test_zh +model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt + +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True +bash local/compute_wer.sh $output_dir $manifest +``` + +# Credits +- [VALL-E](https://github.com/lifeiteng/vall-e) +- [F5-TTS](https://github.com/SWivid/F5-TTS) +- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py b/egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py new file mode 100644 index 000000000..e3d3ff308 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# Copyright 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) use the checkpoint exp_dir/epoch-xxx.pt +python3 bin/generate_averaged_model.py \ + --epoch 40 \ + --avg 5 \ + --exp-dir ${exp_dir} + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt", weights_only=False)`. 
+""" + + +import argparse +from pathlib import Path + +import k2 +import torch +from train import add_model_arguments, get_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, +) +from icefall.utils import AttributeDict + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + add_model_arguments(parser) + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = AttributeDict() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"checkpoint-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + print("About to create model") + filename = f"{params.exp_dir}/epoch-{params.epoch}.pt" + checkpoint = torch.load(filename, map_location=device, weights_only=False) + args = AttributeDict(checkpoint) + model = get_model(args) + + if params.iter > 0: + # TODO FIX ME + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"checkpoint-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + filenames = [ + f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1) + ] + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + checkpoint["model"] = model.state_dict() + torch.save(checkpoint, filename) + + num_param = 
sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py new file mode 100644 index 000000000..52f57b187 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py @@ -0,0 +1,828 @@ +#!/usr/bin/env python3 +# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py +""" +Usage: +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx +# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x +manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst +python3 f5-tts/generate_averaged_model.py \ + --epoch 56 \ + --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --exp-dir exp/f5_small + +# command for text token input +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 + +# command for cosyvoice semantic token input +split=test_zh # seed_tts_eval test_zh +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --split-name $split --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 --use-cosyvoice-semantic-token True + +bash local/compute_wer.sh $output_dir $manifest +""" +import argparse +import logging +import math +import os +import random +import time +from pathlib import Path + +import datasets +import torch +import torch.nn.functional as F +import torchaudio +from accelerate import Accelerator +from bigvganinference import BigVGANInference +from model.cfm import CFM +from model.dit import DiT +from model.modules import MelSpec +from model.utils import convert_char_to_pinyin +from tqdm import tqdm +from train import ( + add_model_arguments, + get_model, + get_tokenizer, + interpolate_tokens, + load_F5_TTS_pretrained_checkpoint, +) + +from icefall.checkpoint import load_checkpoint +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + default="f5-tts/vocab.txt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--model-path", + type=str, + default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--seed", + type=int, + default=0, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--nfe", + type=int, + default=16, + help="The number of steps for the neural ODE", + ) + + parser.add_argument( + "--manifest-file", + type=str, + default=None, + help="The manifest file in seed_tts_eval format", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default="results", + help="The output directory to save the generated wavs", + ) + + parser.add_argument("-ss", "--swaysampling", default=-1, type=float) + + parser.add_argument( + "--interpolate-token", + type=str2bool, + default=True, + help="Interpolate semantic token to match mel frames for CosyVoice", + ) + + parser.add_argument( + "--use-cosyvoice-semantic-token", + type=str2bool, + default=False, + 
help="Whether to use cosyvoice semantic token to replace text token.", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="huggingface dataset split name", + ) + + add_model_arguments(parser) + return parser.parse_args() + + +def get_inference_prompt( + metainfo, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, +): + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( + metainfo, desc="Processing prompts..." + ): + # Audio + ref_audio, ref_sr = torchaudio.load(prompt_wav) + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))) + if ref_rms < target_rms: + ref_audio = ref_audio * target_rms / ref_rms + assert ( + ref_audio.shape[-1] > 5000 + ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio) + + # Text + if len(prompt_text[-1].encode("utf-8")) == 1: + prompt_text = prompt_text + " " + text = [prompt_text + gt_text] + if tokenizer == "pinyin": + text_list = convert_char_to_pinyin(text, polyphone=polyphone) + else: + text_list = text + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + if use_truth_duration: + gt_audio, gt_sr = torchaudio.load(gt_wav) + if gt_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) + gt_audio = resampler(gt_audio) + total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) + + # # test vocoder resynthesis + # ref_audio = gt_audio + else: + ref_text_len = len(prompt_text.encode("utf-8")) + gen_text_len = len(gt_text.encode("utf-8")) + total_mel_len = ref_mel_len + int( + ref_mel_len / ref_text_len * gen_text_len / speed + ) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." 
+ bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + final_text_list[bucket_i].extend(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def get_inference_prompt_cosy_voice_huggingface( + dataset, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, + interpolate_token=False, +): + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for i in range(len(dataset)): + utt = dataset[i]["id"] + ref_audio_org, ref_sr = ( + dataset[i]["prompt_audio"]["array"], + dataset[i]["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() + audio_tokens = dataset[i]["target_audio_cosy2_tokens"] + prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"] + + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + input_tokens = prompt_audio_tokens + audio_tokens + + if interpolate_token: + input_tokens = interpolate_tokens(input_tokens) + text_list = input_tokens + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + + total_mel_len = len(input_tokens) + if not interpolate_token: + total_mel_len = int(total_mel_len / 4 * 15) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." 
+ if total_mel_len > max_tokens: + print( + f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + ) + continue + assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + # final_text_list[bucket_i].extend(text_list) + final_text_list[bucket_i].append(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def inference_speech_token( + cosyvoice, + tts_text, + prompt_text, + prompt_speech_16k, + stream=False, + speed=1.0, + text_frontend=True, +): + tokens = [] + prompt_text = cosyvoice.frontend.text_normalize( + prompt_text, split=False, text_frontend=text_frontend + ) + for i in cosyvoice.frontend.text_normalize( + tts_text, split=True, text_frontend=text_frontend + ): + + tts_text_token, tts_text_token_len = cosyvoice.frontend._extract_text_token(i) + ( + prompt_text_token, + prompt_text_token_len, + ) = cosyvoice.frontend._extract_text_token(prompt_text) + speech_token, speech_token_len = cosyvoice.frontend._extract_speech_token( + prompt_speech_16k + ) + + for i in cosyvoice.model.llm.inference( + text=tts_text_token.to(cosyvoice.model.device), + text_len=torch.tensor([tts_text_token.shape[1]], dtype=torch.int32).to( + cosyvoice.model.device + ), + prompt_text=prompt_text_token.to(cosyvoice.model.device), + prompt_text_len=torch.tensor( + [prompt_text_token.shape[1]], dtype=torch.int32 + ).to(cosyvoice.model.device), + prompt_speech_token=speech_token.to(cosyvoice.model.device), + prompt_speech_token_len=torch.tensor( + [speech_token.shape[1]], dtype=torch.int32 + ).to(cosyvoice.model.device), + embedding=None, + ): + tokens.append(i) + return tokens, speech_token + + +def get_inference_prompt_cosy_voice( + metainfo, + speed=1.0, + tokenizer="pinyin", + polyphone=True, + target_sample_rate=24000, + n_fft=1024, + win_length=1024, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + num_buckets=200, + min_secs=3, + max_secs=40, + interpolate_token=False, +): + + import sys + + # please change the path to the cosyvoice accordingly + sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") + sys.path.append("/workspace/CosyVoice") + 
from cosyvoice.cli.cosyvoice import CosyVoice2 + + # please download the cosyvoice model first + cosyvoice = CosyVoice2( + "/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False + ) + + prompts_all = [] + + min_tokens = min_secs * target_sample_rate // hop_length + max_tokens = max_secs * target_sample_rate // hop_length + + batch_accum = [0] * num_buckets + utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = ( + [[] for _ in range(num_buckets)] for _ in range(6) + ) + + mel_spectrogram = MelSpec( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=target_sample_rate, + mel_spec_type=mel_spec_type, + ) + + for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm( + metainfo, desc="Processing prompts..." + ): + # Audio + ref_audio_org, ref_sr = torchaudio.load(prompt_wav) + + # cosy voice + if ref_sr != 16000: + resampler = torchaudio.transforms.Resample(ref_sr, 16000) + ref_audio_16k = resampler(ref_audio_org) + else: + ref_audio_16k = ref_audio_org + audio_tokens, prompt_audio_tokens = inference_speech_token( + cosyvoice, gt_text, prompt_text, ref_audio_16k, stream=False + ) + + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + assert ( + ref_audio_org.shape[-1] > 5000 + ), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue." + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + # Text + # if len(prompt_text[-1].encode("utf-8")) == 1: + # prompt_text = prompt_text + " " + # text = [prompt_text + gt_text] + # if tokenizer == "pinyin": + # text_list = convert_char_to_pinyin(text, polyphone=polyphone) + # else: + # text_list = text + + # concat two tensors: prompt audio tokens with audio tokens --> shape 1, prompt_audio_tokens + audio_tokens + # prompt_audio_tokens shape 1, prompt_audio_tokens + # audio_tokens shape 1, audio_tokens + prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist() + input_tokens = prompt_audio_tokens + audio_tokens + + # convert it into a list + # input_tokens_list = input_tokens.squeeze().cpu().tolist() + if interpolate_token: + input_tokens = interpolate_tokens(input_tokens) + text_list = input_tokens + + # Duration, mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + if use_truth_duration: + gt_audio, gt_sr = torchaudio.load(gt_wav) + if gt_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate) + gt_audio = resampler(gt_audio) + total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed) + + # # test vocoder resynthesis + # ref_audio = gt_audio + else: + ref_text_len = len(prompt_text.encode("utf-8")) + gen_text_len = len(gt_text.encode("utf-8")) + total_mel_len_compute = ref_mel_len + int( + ref_mel_len / ref_text_len * gen_text_len / speed + ) + total_mel_len = len(input_tokens) + if not interpolate_token: + total_mel_len = int(total_mel_len / 4 * 15) + print( + f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}" + ) + + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + # deal with batch + assert infer_batch_size > 0, "infer_batch_size should be greater than 0." 
+ assert ( + min_tokens <= total_mel_len <= max_tokens + ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]." + bucket_i = math.floor( + (total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets + ) + + utts[bucket_i].append(utt) + ref_rms_list[bucket_i].append(ref_rms) + ref_mels[bucket_i].append(ref_mel) + ref_mel_lens[bucket_i].append(ref_mel_len) + total_mel_lens[bucket_i].append(total_mel_len) + # final_text_list[bucket_i].extend(text_list) + final_text_list[bucket_i].append(text_list) + + batch_accum[bucket_i] += total_mel_len + + if batch_accum[bucket_i] >= infer_batch_size: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + batch_accum[bucket_i] = 0 + ( + utts[bucket_i], + ref_rms_list[bucket_i], + ref_mels[bucket_i], + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # add residual + for bucket_i, bucket_frames in enumerate(batch_accum): + if bucket_frames > 0: + prompts_all.append( + ( + utts[bucket_i], + ref_rms_list[bucket_i], + padded_mel_batch(ref_mels[bucket_i]), + ref_mel_lens[bucket_i], + total_mel_lens[bucket_i], + final_text_list[bucket_i], + ) + ) + # not only leave easy work for last workers + random.seed(666) + random.shuffle(prompts_all) + + return prompts_all + + +def padded_mel_batch(ref_mels): + max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() + padded_ref_mels = [] + for mel in ref_mels: + padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) + padded_ref_mels.append(padded_ref_mel) + padded_ref_mels = torch.stack(padded_ref_mels) + padded_ref_mels = padded_ref_mels.permute(0, 2, 1) + return padded_ref_mels + + +def get_seedtts_testset_metainfo(metalst): + f = open(metalst) + lines = f.readlines() + f.close() + metainfo = [] + for line in lines: + assert len(line.strip().split("|")) == 4 + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + utt = Path(utt).stem + gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav") + if not os.path.isabs(prompt_wav): + prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav) + metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav)) + return metainfo + + +def main(): + args = get_parser() + + accelerator = Accelerator() + device = f"cuda:{accelerator.process_index}" + if args.manifest_file: + metainfo = get_seedtts_testset_metainfo(args.manifest_file) + if not args.use_cosyvoice_semantic_token: + prompts_all = get_inference_prompt( + metainfo, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + ) + else: + prompts_all = get_inference_prompt_cosy_voice( + metainfo, + speed=1.0, + tokenizer="pinyin", + target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + interpolate_token=args.interpolate_token, + ) + else: + assert args.use_cosyvoice_semantic_token + dataset = datasets.load_dataset( + "yuekai/seed_tts_cosy2", + split=args.split_name, + trust_remote_code=True, + ) + prompts_all = get_inference_prompt_cosy_voice_huggingface( + dataset, + speed=1.0, + tokenizer="pinyin", + 
target_sample_rate=24_000, + n_mel_channels=100, + hop_length=256, + mel_spec_type="bigvgan", + target_rms=0.1, + use_truth_duration=False, + infer_batch_size=1, + interpolate_token=args.interpolate_token, + ) + + vocoder = BigVGANInference.from_pretrained( + "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False + ) + vocoder = vocoder.eval().to(device) + + model = get_model(args).eval().to(device) + checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=False) + if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: + model = load_F5_TTS_pretrained_checkpoint(model, args.model_path) + else: + _ = load_checkpoint( + args.model_path, + model=model, + ) + + os.makedirs(args.output_dir, exist_ok=True) + + accelerator.wait_for_everyone() + start = time.time() + + with accelerator.split_between_processes(prompts_all) as prompts: + for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process): + ( + utts, + ref_rms_list, + ref_mels, + ref_mel_lens, + total_mel_lens, + final_text_list, + ) = prompt + ref_mels = ref_mels.to(device) + ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device) + total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) + + if args.use_cosyvoice_semantic_token: + # concat final_text_list + max_len = max([len(tokens) for tokens in final_text_list]) + # pad tokens to the same length + for i, tokens in enumerate(final_text_list): + final_text_list[i] = torch.tensor( + tokens + [-1] * (max_len - len(tokens)), dtype=torch.long + ) + final_text_list = torch.stack(final_text_list).to(device) + + # Inference + with torch.inference_mode(): + generated, _ = model.sample( + cond=ref_mels, + text=final_text_list, + duration=total_mel_lens, + lens=ref_mel_lens, + steps=args.nfe, + cfg_strength=2.0, + sway_sampling_coef=args.swaysampling, + no_ref_audio=False, + seed=args.seed, + ) + for i, gen in enumerate(generated): + gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) + gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) + + generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() + target_rms = 0.1 + target_sample_rate = 24_000 + if ref_rms_list[i] < target_rms: + generated_wave = generated_wave * ref_rms_list[i] / target_rms + torchaudio.save( + f"{args.output_dir}/{utts[i]}.wav", + generated_wave, + target_sample_rate, + ) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + timediff = time.time() - start + print(f"Done batch inference in {timediff / 60 :.2f} minutes.") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/README.md b/egs/wenetspeech4tts/TTS/f5-tts/model/README.md new file mode 100644 index 000000000..e4a7e2a7c --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/README.md @@ -0,0 +1,3 @@ +# Introduction +Files in this folder are copied from +https://github.com/SWivid/F5-TTS/tree/main/src/f5_tts/model diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py new file mode 100644 index 000000000..349c7220e --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py @@ -0,0 +1,326 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +from random import random +from typing import Callable + +import torch +import 
torch.nn.functional as F +from model.modules import MelSpec +from model.utils import ( + default, + exists, + lens_to_mask, + list_str_to_idx, + list_str_to_tensor, + mask_from_frac_lengths, +) +from torch import nn +from torch.nn.utils.rnn import pad_sequence +from torchdiffeq import odeint + + +class CFM(nn.Module): + def __init__( + self, + transformer: nn.Module, + sigma=0.0, + odeint_kwargs: dict = dict( + # atol = 1e-5, + # rtol = 1e-5, + method="euler" # 'midpoint' + ), + audio_drop_prob=0.3, + cond_drop_prob=0.2, + num_channels=None, + mel_spec_module: nn.Module | None = None, + mel_spec_kwargs: dict = dict(), + frac_lengths_mask: tuple[float, float] = (0.7, 1.0), + vocab_char_map: dict[str:int] | None = None, + ): + super().__init__() + + self.frac_lengths_mask = frac_lengths_mask + + # mel spec + self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) + num_channels = default(num_channels, self.mel_spec.n_mel_channels) + self.num_channels = num_channels + + # classifier-free guidance + self.audio_drop_prob = audio_drop_prob + self.cond_drop_prob = cond_drop_prob + + # transformer + self.transformer = transformer + dim = transformer.dim + self.dim = dim + + # conditional flow related + self.sigma = sigma + + # sampling related + self.odeint_kwargs = odeint_kwargs + + # vocab map for tokenization + self.vocab_char_map = vocab_char_map + + @property + def device(self): + return next(self.parameters()).device + + @torch.no_grad() + def sample( + self, + cond: float["b n d"] | float["b nw"], # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + duration: int | int["b"], # noqa: F821 + *, + lens: int["b"] | None = None, # noqa: F821 + steps=32, + cfg_strength=1.0, + sway_sampling_coef=None, + seed: int | None = None, + max_duration=4096, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + no_ref_audio=False, + duplicate_test=False, + t_inter=0.1, + edit_mask=None, + ): + self.eval() + # raw wave + + if cond.ndim == 2: + cond = self.mel_spec(cond) + cond = cond.permute(0, 2, 1) + assert cond.shape[-1] == self.num_channels + + cond = cond.to(next(self.parameters()).dtype) + + batch, cond_seq_len, device = *cond.shape[:2], cond.device + if not exists(lens): + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + # text + + if isinstance(text, list): + if exists(self.vocab_char_map): + text = list_str_to_idx(text, self.vocab_char_map).to(device) + else: + text = list_str_to_tensor(text).to(device) + assert text.shape[0] == batch + + if exists(text): + text_lens = (text != -1).sum(dim=-1) + lens = torch.maximum( + text_lens, lens + ) # make sure lengths are at least those of the text characters + + # duration + + cond_mask = lens_to_mask(lens) + if edit_mask is not None: + cond_mask = cond_mask & edit_mask + + if isinstance(duration, int): + duration = torch.full((batch,), duration, device=device, dtype=torch.long) + + duration = torch.maximum( + lens + 1, duration + ) # just add one token so something is generated + duration = duration.clamp(max=max_duration) + max_duration = duration.amax() + + # duplicate test corner for inner time step oberservation + if duplicate_test: + test_cond = F.pad( + cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0 + ) + + cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + cond_mask = F.pad( + cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False + ) + cond_mask = cond_mask.unsqueeze(-1) + step_cond = torch.where( + cond_mask, cond, 
torch.zeros_like(cond) + ) # allow direct control (cut cond audio) with lens passed in + + if batch > 1: + mask = lens_to_mask(duration) + else: # save memory and speed up, as single inference need no mask currently + mask = None + + # test for no ref audio + if no_ref_audio: + cond = torch.zeros_like(cond) + + # neural ode + + def fn(t, x): + # at each step, conditioning is fixed + # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) + + # predict flow + pred = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=False, + drop_text=False, + ) + if cfg_strength < 1e-5: + return pred + + null_pred = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=True, + drop_text=True, + ) + return pred + (pred - null_pred) * cfg_strength + + # noise input + # to make sure batch inference result is same with different batch size, and for sure single inference + # still some difference maybe due to convolutional layers + y0 = [] + for dur in duration: + if exists(seed): + torch.manual_seed(seed) + y0.append( + torch.randn( + dur, self.num_channels, device=self.device, dtype=step_cond.dtype + ) + ) + y0 = pad_sequence(y0, padding_value=0, batch_first=True) + + t_start = 0 + + # duplicate test corner for inner time step oberservation + if duplicate_test: + t_start = t_inter + y0 = (1 - t_start) * y0 + t_start * test_cond + steps = int(steps * (1 - t_start)) + + t = torch.linspace( + t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype + ) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + + sampled = trajectory[-1] + out = sampled + out = torch.where(cond_mask, cond, out) + + if exists(vocoder): + out = out.permute(0, 2, 1) + out = vocoder(out) + + return out, trajectory + + def forward( + self, + inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + *, + lens: int["b"] | None = None, # noqa: F821 + noise_scheduler: str | None = None, + ): + # handle raw wave + if inp.ndim == 2: + inp = self.mel_spec(inp) + inp = inp.permute(0, 2, 1) + assert inp.shape[-1] == self.num_channels + + batch, seq_len, dtype, device, _σ1 = ( + *inp.shape[:2], + inp.dtype, + self.device, + self.sigma, + ) + + # handle text as string + if isinstance(text, list): + if exists(self.vocab_char_map): + text = list_str_to_idx(text, self.vocab_char_map).to(device) + else: + text = list_str_to_tensor(text).to(device) + assert text.shape[0] == batch + + # lens and mask + if not exists(lens): + lens = torch.full((batch,), seq_len, device=device) + + mask = lens_to_mask( + lens, length=seq_len + ) # useless here, as collate_fn will pad to max length in batch + + # get a random span to mask out for training conditionally + frac_lengths = ( + torch.zeros((batch,), device=self.device) + .float() + .uniform_(*self.frac_lengths_mask) + ) + rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) + + if exists(mask): + rand_span_mask &= mask + + # mel is x1 + x1 = inp + + # x0 is gaussian noise + x0 = torch.randn_like(x1) + + # time step + time = torch.rand((batch,), dtype=dtype, device=self.device) + # TODO. 
noise_scheduler + + # sample xt (φ_t(x) in the paper) + t = time.unsqueeze(-1).unsqueeze(-1) + φ = (1 - t) * x0 + t * x1 + flow = x1 - x0 + + # only predict what is within the random mask span for infilling + cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) + + # transformer and cfg training with a drop rate + drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper + if random() < self.cond_drop_prob: # p_uncond in voicebox paper + drop_audio_cond = True + drop_text = True + else: + drop_text = False + + # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here + # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences + pred = self.transformer( + x=φ, + cond=cond, + text=text, + time=time, + drop_audio_cond=drop_audio_cond, + drop_text=drop_text, + ) + + # flow matching loss + loss = F.mse_loss(pred, flow, reduction="none") + loss = loss[rand_span_mask] + + return loss.mean(), cond, pred diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py b/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py new file mode 100644 index 000000000..966fabfdd --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py @@ -0,0 +1,210 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from model.modules import ( + AdaLayerNormZero_Final, + ConvNeXtV2Block, + ConvPositionEmbedding, + DiTBlock, + TimestepEmbedding, + get_pos_embed_indices, + precompute_freqs_cis, +) +from torch import nn +from x_transformers.x_transformers import RotaryEmbedding + +# Text embedding + + +class TextEmbedding(nn.Module): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2): + super().__init__() + self.text_embed = nn.Embedding( + text_num_embeds + 1, text_dim + ) # use 0 as filler token + + if conv_layers > 0: + self.extra_modeling = True + self.precompute_max_pos = 4096 # ~44s of 24khz audio + self.register_buffer( + "freqs_cis", + precompute_freqs_cis(text_dim, self.precompute_max_pos), + persistent=False, + ) + self.text_blocks = nn.Sequential( + *[ + ConvNeXtV2Block(text_dim, text_dim * conv_mult) + for _ in range(conv_layers) + ] + ) + else: + self.extra_modeling = False + + def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722 + text = ( + text + 1 + ) # use 0 as filler token. 
preprocess of batch pad -1, see list_str_to_idx() + text = text[ + :, :seq_len + ] # curtail if character tokens are more than the mel spec tokens + batch, text_len = text.shape[0], text.shape[1] + text = F.pad(text, (0, seq_len - text_len), value=0) + + if drop_text: # cfg for text + text = torch.zeros_like(text) + + text = self.text_embed(text) # b n -> b n d + + # possible extra modeling + if self.extra_modeling: + # sinus pos emb + batch_start = torch.zeros((batch,), dtype=torch.long) + pos_idx = get_pos_embed_indices( + batch_start, seq_len, max_pos=self.precompute_max_pos + ) + text_pos_embed = self.freqs_cis[pos_idx] + text = text + text_pos_embed + + # convnextv2 blocks + text = self.text_blocks(text) + + return text + + +# noised input audio and context mixing embedding + + +class InputEmbedding(nn.Module): + def __init__(self, mel_dim, text_dim, out_dim): + super().__init__() + self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) + + def forward( + self, + x: float["b n d"], # noqa: F722 + cond: float["b n d"], # noqa: F722 + text_embed: float["b n d"], # noqa: F722 + drop_audio_cond=False, + ): + if drop_audio_cond: # cfg for cond audio + cond = torch.zeros_like(cond) + + x = self.proj(torch.cat((x, cond, text_embed), dim=-1)) + x = self.conv_pos_embed(x) + x + return x + + +# Transformer backbone using DiT blocks + + +class DiT(nn.Module): + def __init__( + self, + *, + dim, + depth=8, + heads=8, + dim_head=64, + dropout=0.1, + ff_mult=4, + mel_dim=100, + text_num_embeds=256, + text_dim=None, + conv_layers=0, + long_skip_connection=False, + checkpoint_activations=False, + ): + super().__init__() + + self.time_embed = TimestepEmbedding(dim) + if text_dim is None: + text_dim = mel_dim + self.text_embed = TextEmbedding( + text_num_embeds, text_dim, conv_layers=conv_layers + ) + self.input_embed = InputEmbedding(mel_dim, text_dim, dim) + + self.rotary_embed = RotaryEmbedding(dim_head) + + self.dim = dim + self.depth = depth + + self.transformer_blocks = nn.ModuleList( + [ + DiTBlock( + dim=dim, + heads=heads, + dim_head=dim_head, + ff_mult=ff_mult, + dropout=dropout, + ) + for _ in range(depth) + ] + ) + self.long_skip_connection = ( + nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None + ) + + self.norm_out = AdaLayerNormZero_Final(dim) # final modulation + self.proj_out = nn.Linear(dim, mel_dim) + + self.checkpoint_activations = checkpoint_activations + + def ckpt_wrapper(self, module): + # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py + def ckpt_forward(*inputs): + outputs = module(*inputs) + return outputs + + return ckpt_forward + + def forward( + self, + x: float["b n d"], # nosied input audio # noqa: F722 + cond: float["b n d"], # masked cond audio # noqa: F722 + text: int["b nt"], # text # noqa: F722 + time: float["b"] | float[""], # time step # noqa: F821 F722 + drop_audio_cond, # cfg for cond audio + drop_text, # cfg for text + mask: bool["b n"] | None = None, # noqa: F722 + ): + batch, seq_len = x.shape[0], x.shape[1] + if time.ndim == 0: + time = time.repeat(batch) + + # t: conditioning time, c: context (text + masked cond audio), x: noised input audio + t = self.time_embed(time) + text_embed = self.text_embed(text, seq_len, drop_text=drop_text) + x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond) + + rope = self.rotary_embed.forward_from_seq_len(seq_len) + + if self.long_skip_connection is not None: + residual = x + + for block in 
self.transformer_blocks: + if self.checkpoint_activations: + x = torch.utils.checkpoint.checkpoint( + self.ckpt_wrapper(block), x, t, mask, rope + ) + else: + x = block(x, t, mask=mask, rope=rope) + + if self.long_skip_connection is not None: + x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) + + x = self.norm_out(x, t) + output = self.proj_out(x) + + return output diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py b/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py new file mode 100644 index 000000000..05299d419 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py @@ -0,0 +1,728 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +import torchaudio +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from x_transformers.x_transformers import apply_rotary_pos_emb + +# raw wav to mel spec + + +mel_basis_cache = {} +hann_window_cache = {} + + +def get_bigvgan_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, + fmin=0, + fmax=None, + center=False, +): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main + device = waveform.device + key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn( + sr=target_sample_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=fmin, + fmax=fmax, + ) + mel_basis_cache[key] = ( + torch.from_numpy(mel).float().to(device) + ) # TODO: why they need .float()? + hann_window_cache[key] = torch.hann_window(win_length).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_length) // 2 + waveform = torch.nn.functional.pad( + waveform.unsqueeze(1), (padding, padding), mode="reflect" + ).squeeze(1) + + spec = torch.stft( + waveform, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5)) + + return mel_spec + + +def get_vocos_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, +): + mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ).to(waveform.device) + if len(waveform.shape) == 3: + waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' + + assert len(waveform.shape) == 2 + + mel = mel_stft(waveform) + mel = mel.clamp(min=1e-5).log() + return mel + + +class MelSpec(nn.Module): + def __init__( + self, + n_fft=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24_000, + mel_spec_type="vocos", + ): + super().__init__() + assert mel_spec_type in ["vocos", "bigvgan"], print( + "We only support two extract mel backend: vocos or bigvgan" + ) + + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.target_sample_rate = target_sample_rate + + 
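+        # Choose the mel-extraction backend: "vocos" uses torchaudio's MelSpectrogram
+        # (center=True), while "bigvgan" follows NVIDIA BigVGAN's recipe (librosa mel
+        # basis, reflect padding, center=False).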
if mel_spec_type == "vocos": + self.extractor = get_vocos_mel_spectrogram + elif mel_spec_type == "bigvgan": + self.extractor = get_bigvgan_mel_spectrogram + + self.register_buffer("dummy", torch.tensor(0), persistent=False) + + def forward(self, wav): + if self.dummy.device != wav.device: + self.to(wav.device) + + mel = self.extractor( + waveform=wav, + n_fft=self.n_fft, + n_mel_channels=self.n_mel_channels, + target_sample_rate=self.target_sample_rate, + hop_length=self.hop_length, + win_length=self.win_length, + ) + + return mel + + +# sinusoidal position embedding + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +# convolutional position embedding + + +class ConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + ) + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(~mask, 0.0) + + x = x.permute(0, 2, 1) + x = self.conv1d(x) + out = x.permute(0, 2, 1) + + if mask is not None: + out = out.masked_fill(~mask, 0.0) + + return out + + +# rotary positional embedding related + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0 +): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): + # length = length if isinstance(length, int) else length.max() + scale = scale * torch.ones_like( + start, dtype=torch.float32 + ) # in case scale is a scalar + pos = ( + start.unsqueeze(1) + + ( + torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) + * scale.unsqueeze(1) + ).long() + ) + # avoid extra long error. + pos = torch.where(pos < max_pos, pos, max_pos - 1) + return pos + + +# Global Response Normalization layer (Instance Normalization ?) 
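+# GRN (from ConvNeXt V2): the L2 response of each channel over the sequence is normalized by the
+# mean response across channels, then the input is rescaled with learned gamma/beta plus a residual.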
+ + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=1, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py +# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation + + +class AdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk( + emb, 6, dim=1 + ) + + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation + + +class AdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +# FeedForward + + +class FeedForward(nn.Module): + def __init__( + self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none" + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.ff(x) + + +# Attention with possible joint part +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class Attention(nn.Module): + def __init__( + self, + processor: JointAttnProcessor | AttnProcessor, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only=None, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + 
"Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.processor = processor + + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + + self.context_dim = context_dim + self.context_pre_only = context_pre_only + + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + + if self.context_dim is not None: + self.to_k_c = nn.Linear(context_dim, self.inner_dim) + self.to_v_c = nn.Linear(context_dim, self.inner_dim) + if self.context_pre_only is not None: + self.to_q_c = nn.Linear(context_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, dim)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_out_c = nn.Linear(self.inner_dim, dim) + + def forward( + self, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b n d"] = None, # context c # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.Tensor: + if c is not None: + return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) + else: + return self.processor(self, x, mask=mask, rope=rope) + + +# Attention processor + + +class AttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) + if xpos_scale is not None + else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. 
inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + else: + attn_mask = None + + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +# Joint Attention processor for MM-DiT +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class JointAttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b nt d"] = None, # context c, here text # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.FloatTensor: + residual = x + + batch_size = c.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # `context` projections. + c_query = attn.to_q_c(c) + c_key = attn.to_k_c(c) + c_value = attn.to_v_c(c) + + # apply rope for context and noised input independently + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) + if xpos_scale is not None + else (1.0, 1.0) + ) + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + if c_rope is not None: + freqs, xpos_scale = c_rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) + if xpos_scale is not None + else (1.0, 1.0) + ) + c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) + c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) + + # attention + query = torch.cat([query, c_query], dim=1) + key = torch.cat([key, c_key], dim=1) + value = torch.cat([value, c_value], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text) + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + else: + attn_mask = None + + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # Split the attention outputs. + x, c = ( + x[:, : residual.shape[1]], + x[:, residual.shape[1] :], + ) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + if not attn.context_pre_only: + c = attn.to_out_c(c) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + # c = c.masked_fill(~mask, 0.) 
# no mask for c (text) + + return x, c + + +# DiT Block + + +class DiTBlock(nn.Module): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): + super().__init__() + + self.attn_norm = AdaLayerNormZero(dim) + self.attn = Attention( + processor=AttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + + self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) + + def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + + # attention + attn_output = self.attn(x=norm, mask=mask, rope=rope) + + # process attention output for input x + x = x + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + x = x + gate_mlp.unsqueeze(1) * ff_output + + return x + + +# MMDiT Block https://arxiv.org/abs/2403.03206 + + +class MMDiTBlock(nn.Module): + r""" + modified from diffusers/src/diffusers/models/attention.py + + notes. + _c: context related. text, cond, etc. (left part in sd3 fig2.b) + _x: noised input related. (right part) + context_pre_only: last layer only do prenorm + modulation cuz no more ffn + """ + + def __init__( + self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False + ): + super().__init__() + + self.context_pre_only = context_pre_only + + self.attn_norm_c = ( + AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) + ) + self.attn_norm_x = AdaLayerNormZero(dim) + self.attn = Attention( + processor=JointAttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + context_dim=dim, + context_pre_only=context_pre_only, + ) + + if not context_pre_only: + self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_c = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) + else: + self.ff_norm_c = None + self.ff_c = None + self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_x = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) + + def forward( + self, x, c, t, mask=None, rope=None, c_rope=None + ): # x: noised input, c: context, t: time embedding + # pre-norm & modulation for attention input + if self.context_pre_only: + norm_c = self.attn_norm_c(c, t) + else: + norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c( + c, emb=t + ) + norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x( + x, emb=t + ) + + # attention + x_attn_output, c_attn_output = self.attn( + x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope + ) + + # process attention output for context c + if self.context_pre_only: + c = None + else: # if not last layer + c = c + c_gate_msa.unsqueeze(1) * c_attn_output + + norm_c = ( + self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + ) + c_ff_output = self.ff_c(norm_c) + c = c + c_gate_mlp.unsqueeze(1) * c_ff_output + + # process attention output for input x + x = x + x_gate_msa.unsqueeze(1) * x_attn_output + + norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] + x_ff_output = self.ff_x(norm_x) + x = x + x_gate_mlp.unsqueeze(1) * x_ff_output + + return c, x + + +# time step conditioning embedding + + +class TimestepEmbedding(nn.Module): + def 
__init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential( + nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) + + def forward(self, timestep: float["b"]): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py b/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py new file mode 100644 index 000000000..fae5fadb6 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import os +import random +from collections import defaultdict +from importlib.resources import files + +import jieba +import torch +from pypinyin import Style, lazy_pinyin +from torch.nn.utils.rnn import pad_sequence + +# seed everything + + +def seed_everything(seed=0): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +# helpers + + +def exists(v): + return v is not None + + +def default(v, d): + return v if exists(v) else d + + +# tensor helpers + + +def lens_to_mask( + t: int["b"], length: int | None = None # noqa: F722 F821 +) -> bool["b n"]: # noqa: F722 F821 + if not exists(length): + length = t.amax() + + seq = torch.arange(length, device=t.device) + return seq[None, :] < t[:, None] + + +def mask_from_start_end_indices( + seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821 +): + max_seq_len = seq_len.max().item() + seq = torch.arange(max_seq_len, device=start.device).long() + start_mask = seq[None, :] >= start[:, None] + end_mask = seq[None, :] < end[:, None] + return start_mask & end_mask + + +def mask_from_frac_lengths( + seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821 +): + lengths = (frac_lengths * seq_len).long() + max_start = seq_len - lengths + + rand = torch.rand_like(frac_lengths) + start = (max_start * rand).long().clamp(min=0) + end = start + lengths + + return mask_from_start_end_indices(seq_len, start, end) + + +def maybe_masked_mean( + t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821 +) -> float["b d"]: # noqa: F722 F821 + if not exists(mask): + return t.mean(dim=1) + + t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device)) + num = t.sum(dim=1) + den = mask.float().sum(dim=1) + + return num / den.clamp(min=1.0) + + +# simple utf-8 tokenizer, since paper went character based +def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722 + list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style + text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True) + return text + + +# char tokenizer, based on custom dataset's extracted .txt file +def list_str_to_idx( + text: list[str] | list[list[str]], + vocab_char_map: dict[str, int], # {char: idx} + padding_value=-1, +) -> int["b nt"]: # noqa: F722 + list_idx_tensors = [ + torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text + ] # pinyin or char style + text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) + return text + + +# Get tokenizer + + +def get_tokenizer(dataset_name, tokenizer: str = "pinyin"): + """ + tokenizer - "pinyin" do g2p for only chinese characters, need 
.txt vocab_file + - "char" for char-wise tokenizer, need .txt vocab_file + - "byte" for utf-8 tokenizer + - "custom" if you're directly passing in a path to the vocab.txt you want to use + vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols + - if use "char", derived from unfiltered character & symbol counts of custom dataset + - if use "byte", set to 256 (unicode byte range) + """ + if tokenizer in ["pinyin", "char"]: + tokenizer_path = os.path.join( + files("f5_tts").joinpath("../../data"), + f"{dataset_name}_{tokenizer}/vocab.txt", + ) + with open(tokenizer_path, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + assert ( + vocab_char_map[" "] == 0 + ), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char" + + elif tokenizer == "byte": + vocab_char_map = None + vocab_size = 256 + + elif tokenizer == "custom": + with open(dataset_name, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + + return vocab_char_map, vocab_size + + +# convert char to pinyin + +jieba.initialize() +print("Word segmentation module jieba initialized.\n") + + +def convert_char_to_pinyin(text_list, polyphone=True): + final_text_list = [] + custom_trans = str.maketrans( + {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} + ) # add custom trans here, to address oov + + def is_chinese(c): + return "\u3100" <= c <= "\u9fff" # common chinese characters + + for text in text_list: + char_list = [] + text = text.translate(custom_trans) + for seg in jieba.cut(text): + seg_byte_len = len(bytes(seg, "UTF-8")) + if seg_byte_len == len(seg): # if pure alphabets and symbols + if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": + char_list.append(" ") + char_list.extend(seg) + elif polyphone and seg_byte_len == 3 * len( + seg + ): # if pure east asian characters + seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) + for i, c in enumerate(seg): + if is_chinese(c): + char_list.append(" ") + char_list.append(seg_[i]) + else: # if mixed characters, alphabets and symbols + for c in seg: + if ord(c) < 256: + char_list.extend(c) + elif is_chinese(c): + char_list.append(" ") + char_list.extend( + lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True) + ) + else: + char_list.append(c) + final_text_list.append(char_list) + + return final_text_list + + +# filter func for dirty data with many repetitions + + +def repetition_found(text, length=2, tolerance=10): + pattern_count = defaultdict(int) + for i in range(len(text) - length + 1): + pattern = text[i : i + length] + pattern_count[pattern] += 1 + for pattern, count in pattern_count.items(): + if count > tolerance: + return True + return False diff --git a/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt b/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt new file mode 100644 index 000000000..63f1e237c --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt @@ -0,0 +1,36 @@ +# F5-TTS +accelerate>=0.33.0 +bitsandbytes>0.37.0 +cached_path +click +datasets +ema_pytorch>=0.5.2 +gradio>=3.45.2 +hydra-core>=1.3.0 +jieba +librosa +matplotlib +numpy<=1.26.4 +pydub +pypinyin +safetensors +soundfile +tomli +torch>=2.0.0 +torchaudio>=2.0.0 +torchdiffeq +tqdm>=4.65.0 +transformers +x_transformers>=1.31.14 + +# icefall +kaldialign +lhotse +tensorboard +bigvganinference +sentencepiece +sherpa-onnx 
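+# Note: k2 wheels are built per torch/CUDA version; see the install command in f5-tts/train.py, e.g.
+# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html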
+k2 + +# semantic experiment +s3tokenizer diff --git a/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py b/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py new file mode 100644 index 000000000..7d42a00a5 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py @@ -0,0 +1,107 @@ +from typing import Callable, Dict, List, Sequence, Union + +import torch +from lhotse import validate +from lhotse.cut import CutSet +from lhotse.dataset.collation import collate_audio +from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures +from lhotse.utils import ifnone + + +class SpeechSynthesisDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the speech synthesis task. + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'audio': (B x NumSamples) float tensor + 'features': (B x NumFrames x NumFeatures) float tensor + 'audio_lens': (B, ) int tensor + 'features_lens': (B, ) int tensor + 'text': List[str] of len B # when return_text=True + 'tokens': List[List[str]] # when return_tokens=True + 'speakers': List[str] of len B # when return_spk_ids=True + 'cut': List of Cuts # when return_cuts=True + } + """ + + def __init__( + self, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + feature_input_strategy: BatchIO = PrecomputedFeatures(), + feature_transforms: Union[Sequence[Callable], Callable] = None, + return_text: bool = True, + return_tokens: bool = False, + return_spk_ids: bool = False, + return_cuts: bool = False, + ) -> None: + super().__init__() + + self.cut_transforms = ifnone(cut_transforms, []) + self.feature_input_strategy = feature_input_strategy + + self.return_text = return_text + self.return_tokens = return_tokens + self.return_spk_ids = return_spk_ids + self.return_cuts = return_cuts + + if feature_transforms is None: + feature_transforms = [] + elif not isinstance(feature_transforms, Sequence): + feature_transforms = [feature_transforms] + + assert all( + isinstance(transform, Callable) for transform in feature_transforms + ), "Feature transforms must be Callable" + self.feature_transforms = feature_transforms + + def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: + validate_for_tts(cuts) + + for transform in self.cut_transforms: + cuts = transform(cuts) + + # audio, audio_lens = collate_audio(cuts) + features, features_lens = self.feature_input_strategy(cuts) + + for transform in self.feature_transforms: + features = transform(features) + + batch = { + # "audio": audio, + "features": features, + # "audio_lens": audio_lens, + "features_lens": features_lens, + } + + if self.return_text: + # use normalized text + # text = [cut.supervisions[0].normalized_text for cut in cuts] + text = [cut.supervisions[0].text for cut in cuts] + batch["text"] = text + + if self.return_tokens and "speech_tokens" in cuts[0].supervisions[0].custom: + # tokens = [cut.tokens for cut in cuts] + # tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts] + tokens = [cut.supervisions[0].custom["speech_tokens"] for cut in cuts] + # change str into list + tokens = [list(map(int, token.split())) for token in tokens] + batch["tokens"] = tokens + + if self.return_spk_ids: + batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts] + + if self.return_cuts: + batch["cut"] = [cut for cut in cuts] + + return batch + + +def validate_for_tts(cuts: CutSet) -> None: + validate(cuts) + for cut in cuts: + assert ( + len(cut.supervisions) == 1 + ), "Only the Cuts with single supervision are supported." 
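+
+
+# Minimal usage sketch (paths and sampler settings below are illustrative, not part of this recipe):
+#
+#     from torch.utils.data import DataLoader
+#     from lhotse import CutSet
+#     from lhotse.dataset import DynamicBucketingSampler
+#
+#     cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")
+#     sampler = DynamicBucketingSampler(cuts, max_duration=200, shuffle=True)
+#     dataset = SpeechSynthesisDataset(return_tokens=True, return_cuts=True)
+#     loader = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)
+#     batch = next(iter(loader))  # dict with "features", "features_lens", "text", ...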
diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py new file mode 100755 index 000000000..0cc0bf240 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -0,0 +1,1233 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo) +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece + +world_size=8 +exp_dir=exp/f5-tts-small +python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ + --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \ + --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +# command for training with cosyvoice semantic token +exp_dir=exp/f5-tts-cosyvoice +python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ + --base-lr 1e-4 --warmup-steps 20000 --average-period 0 \ + --num-epochs 10 --start-epoch 1 --start-batch 0 \ + --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ + --exp-dir ${exp_dir} --world-size ${world_size} \ + --decay-steps 600000 --prefix wenetspeech4tts_cosy_token --use-cosyvoice-semantic-token True +""" + +import argparse +import copy +import logging +import os +import random +import warnings +from contextlib import nullcontext +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model.cfm import CFM +from model.dit import DiT +from model.utils import convert_char_to_pinyin +from torch import Tensor +from torch.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import LinearLR, SequentialLR +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import TtsDataModule +from utils import MetricsTracker + +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, 
setup_logger, str2bool # MetricsTracker + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoder-dim", + type=int, + default=1024, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--nhead", + type=int, + default=16, + help="Number of attention heads in the Decoder layers.", + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=22, + help="Number of Decoder layers.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="exp/f5", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="f5-tts/vocab.txt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--pretrained-model-path", + type=str, + default=None, + help="Path to file", + ) + + parser.add_argument( + "--optimizer-name", + type=str, + default="AdamW", + help="The optimizer.", + ) + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=200, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--decay-steps", + type=int, + default=1000000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train %% save_every_n == 0. 
The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + parser.add_argument( + "--valid-interval", + type=int, + default=10000, + help="""Run validation if batch_idx %% valid_interval is 0.""", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=0, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accumulate-grad-steps", + type=int, + default=1, + help="""update gradient when batch_idx_train %% accumulate_grad_steps == 0. + """, + ) + + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="Training dtype: float32 bfloat16 float16.", + ) + + parser.add_argument( + "--filter-min-duration", + type=float, + default=0.0, + help="Keep only utterances with duration > this.", + ) + + parser.add_argument( + "--filter-max-duration", + type=float, + default=20.0, + help="Keep only utterances with duration < this.", + ) + + parser.add_argument( + "--oom-check", + type=str2bool, + default=False, + help="perform OOM check on dataloader batches before starting training.", + ) + + parser.add_argument( + "--use-cosyvoice-semantic-token", + type=str2bool, + default=False, + help="Whether to use cosyvoice semantic token to replace text token.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. 
+ + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 200, + "valid_interval": 10000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_tokenizer(vocab_file_path: str): + """ + tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file + - "char" for char-wise tokenizer, need .txt vocab_file + - "byte" for utf-8 tokenizer + - "custom" if you're directly passing in a path to the vocab.txt you want to use + vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols + - if use "char", derived from unfiltered character & symbol counts of custom dataset + - if use "byte", set to 256 (unicode byte range) + """ + with open(vocab_file_path, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + + return vocab_char_map, vocab_size + + +def get_model(params): + if params.use_cosyvoice_semantic_token: + # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36 + vocab_char_map, vocab_size = None, 6561 + else: + vocab_char_map, vocab_size = get_tokenizer(params.tokens) + # bigvgan 100 dim features + n_mel_channels = 100 + n_fft = 1024 + sampling_rate = 24_000 + hop_length = 256 + win_length = 1024 + + model_cfg = { + "dim": params.decoder_dim, + "depth": params.num_decoder_layers, + "heads": params.nhead, + "ff_mult": 2, + "text_dim": 512, + "conv_layers": 4, + "checkpoint_activations": False, + } + model = CFM( + transformer=DiT( + **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels + ), + mel_spec_kwargs=dict( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=sampling_rate, + mel_spec_type="bigvgan", + ), + odeint_kwargs=dict( + method="euler", + ), + vocab_char_map=vocab_char_map, + ) + return model + + +def load_F5_TTS_pretrained_checkpoint( + model, ckpt_path, device: str = "cpu", dtype=torch.float32 +): + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) + if "ema_model_state_dict" in checkpoint: + checkpoint["model_state_dict"] = { + k.replace("ema_model.", ""): v + for k, v in checkpoint["ema_model_state_dict"].items() + if k not in ["initted", "step"] + } + + # patch for backward compatibility, 305e3ea + for key in [ + "mel_spec.mel_stft.mel_scale.fb", + "mel_spec.mel_stft.spectrogram.window", + ]: + if key in checkpoint["model_state_dict"]: + del checkpoint["model_state_dict"][key] + model.load_state_dict(checkpoint["model_state_dict"]) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. 
+ + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + if isinstance(model, DDP): + raise ValueError("load_checkpoint before DDP") + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
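+      rank:
+        The rank of the node in DDP training. Only rank 0 writes the
+        checkpoint; it is saved to exp-dir/epoch-{cur_epoch}.pt and is
+        additionally copied to best-train-loss.pt / best-valid-loss.pt when
+        the current epoch achieves the best training / validation loss.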
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def interpolate_tokens(cosy_tokens, pad_token=-1): + """Interpolate cosyvoice tokens to match bigvgan frames length""" + # cosyvoice, 25 tokens/sec + # bigvgan sample_rate/hop_length 24000/256 frames/sec + # For every 4 cosyvoice tokens, insert pad tokens to extend it to 15 tokens to match bigvgan frames length + # We choose 4,4,4,3 to match 15 frames + three, two = [pad_token] * 3, [pad_token] * 2 + return [ + x + for i, e in enumerate(cosy_tokens) + for x in ([e] + three if i % 4 < 3 else [e] + two) + ] + + +def prepare_input( + batch: dict, device: torch.device, use_cosyvoice_semantic_token: bool +): + """Parse batch data""" + mel_spec = batch["features"] + mel_lengths = batch["features_lens"] + + if use_cosyvoice_semantic_token: + semantic_tokens = [] + for i in range(len(batch["tokens"])): + tokens = batch["tokens"][i] + tokens = interpolate_tokens(tokens) + semantic_tokens.append(tokens) + # pad to the same length, B,T, with pad value -1 + max_len = max([len(tokens) for tokens in semantic_tokens]) + text_inputs = torch.full( + (len(semantic_tokens), max_len), -1, dtype=torch.long + ).to(device) + for i, tokens in enumerate(semantic_tokens): + text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long) + else: + text_inputs = batch["text"] + text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True) + + return text_inputs, mel_spec.to(device), mel_lengths.to(device) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + (text_inputs, mel_spec, mel_lengths) = prepare_input( + batch, + device=device, + use_cosyvoice_semantic_token=params.use_cosyvoice_semantic_token, + ) + # at entry, TextTokens is (N, P) + + with torch.set_grad_enabled(is_training): + loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths) + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["samples"] = mel_lengths.size(0) + + info["loss"] = loss.detach().cpu().item() * info["samples"] + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + if world_size > 1: + tot_loss.reduce(loss.device) + loss_value = tot_loss["loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + rng: + Random for selecting. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + tot_loss = MetricsTracker() + iter_dl = iter(train_dl) + + dtype, enabled = torch.float32, False + if params.dtype in ["bfloat16", "bf16"]: + dtype, enabled = torch.bfloat16, True + elif params.dtype in ["float16", "fp16"]: + dtype, enabled = torch.float16, True + + batch_idx = 0 + while True: + try: + batch = next(iter_dl) + except StopIteration: + logging.info("Reaches end of dataloader.") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["text"]) + + try: + with torch.amp.autocast("cuda", dtype=dtype, enabled=enabled): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * ( + 1 / params.reset_interval + ) + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + scaler.scale(loss).backward() + if params.batch_idx_train >= params.accumulate_grad_steps: + if params.batch_idx_train % params.accumulate_grad_steps == 0: + + # Unscales the gradients of optimizer's assigned params in-place + scaler.unscale_(optimizer) + # Since the gradients of optimizer's assigned params are unscaled, clips as usual: + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + # loss.backward() + # optimizer.step() + + for k in range(params.accumulate_grad_steps): + scheduler.step() + + set_batch_count(model, params.batch_idx_train) + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.average_period > 0: + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = ( + scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0 + ) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, train_loss[{loss_info}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + + ( + f", grad_scale: {cur_grad_scale}" + if params.dtype in ["float16", "fp16"] + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, + "train/current_", + params.batch_idx_train, + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.dtype in ["float16", "fp16"]: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if params.batch_idx_train % params.valid_interval == 0: + # Calculate validation loss in Rank 0 + model.eval() + logging.info("Computing validation loss") + with torch.amp.autocast("cuda", dtype=dtype): + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + model.train() + + loss_value = tot_loss["loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances( + cuts: CutSet, min_duration: float, max_duration: float +) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 0.6 second and 20 seconds + if c.duration < min_duration or c.duration > max_duration: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info(f"Device: {device}") + tokenizer = get_tokenizer(params.tokens) + logging.info(params) + + logging.info("About to create model") + + model = get_model(params) + + if params.pretrained_model_path: + checkpoint = torch.load(params.pretrained_model_path, map_location="cpu", weights_only=False) + if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: + model = load_F5_TTS_pretrained_checkpoint( + model, params.pretrained_model_path + ) + else: + _ = load_checkpoint( + params.pretrained_model_path, + model=model, + ) + + model = model.to(device) + + with open(f"{params.exp_dir}/model.txt", "w") as f: + print(model) + print(model, file=f) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0 and params.average_period > 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=False) + + model_parameters = model.parameters() + + optimizer = torch.optim.AdamW( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + weight_decay=1e-2, + eps=1e-8, + ) + + warmup_scheduler = LinearLR( + optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps + ) + decay_scheduler = LinearLR( + optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup_scheduler, decay_scheduler], + milestones=[params.warmup_steps], + ) + + optimizer.zero_grad() + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.inf_check: + register_inf_check_hooks(model) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + dataset = TtsDataModule(args) + train_cuts = dataset.train_cuts() + valid_cuts = dataset.valid_cuts() + + train_cuts = filter_short_and_long_utterances( + train_cuts, params.filter_min_duration, params.filter_max_duration + ) + valid_cuts = filter_short_and_long_utterances( + valid_cuts, params.filter_min_duration, 
params.filter_max_duration + ) + + train_dl = dataset.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + valid_dl = dataset.valid_dataloaders(valid_cuts) + + if params.oom_check: + scan_pessimistic_batches_for_oom( + model=model, + tokenizer=tokenizer, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler( + "cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0 + ) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + tokenizer, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + dtype = torch.float32 + if params.dtype in ["bfloat16", "bf16"]: + dtype = torch.bfloat16 + elif params.dtype in ["float16", "fp16"]: + dtype = torch.float16 + + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + print(batch.keys()) + try: + with torch.amp.autocast("cuda", dtype=dtype): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + loss.backward(retain_graph=True) + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + TtsDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py new file mode 100644 index 000000000..eab7588b7 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py @@ -0,0 +1,306 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset, + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from speech_synthesis import SpeechSynthesisDataset # noqa F401 +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class TtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
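+
+    A minimal usage sketch (all names below are methods of this class):
+
+        parser = argparse.ArgumentParser()
+        TtsDataModule.add_arguments(parser)
+        args = parser.parse_args()
+
+        dm = TtsDataModule(args)
+        train_dl = dm.train_dataloaders(dm.train_cuts())
+        valid_dl = dm.valid_dataloaders(dm.valid_cuts())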
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + parser.add_argument( + "--prefix", + type=str, + default="wenetspeech4tts", + help="prefix of the manifest file", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + raise NotImplementedError( + "On-the-fly feature extraction is not implemented yet." 
+ ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + pin_memory=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + raise NotImplementedError( + "On-the-fly feature extraction is not implemented yet." + ) + else: + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=True, + pin_memory=True, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + raise NotImplementedError( + "On-the-fly feature extraction is not implemented yet." 
+ ) + else: + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"{self.args.prefix}_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"{self.args.prefix}_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"{self.args.prefix}_cuts_test.jsonl.gz" + ) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/utils.py b/egs/wenetspeech4tts/TTS/f5-tts/utils.py new file mode 120000 index 000000000..ceaaea196 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/utils.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/f5-tts/vocab.txt b/egs/wenetspeech4tts/TTS/f5-tts/vocab.txt new file mode 100644 index 000000000..93f8b48b2 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/vocab.txt @@ -0,0 +1,2545 @@ + +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +: +; += +> +? +@ +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +[ +\ +] +_ +a +a1 +ai1 +ai2 +ai3 +ai4 +an1 +an3 +an4 +ang1 +ang2 +ang4 +ao1 +ao2 +ao3 +ao4 +b +ba +ba1 +ba2 +ba3 +ba4 +bai1 +bai2 +bai3 +bai4 +ban1 +ban2 +ban3 +ban4 +bang1 +bang2 +bang3 +bang4 +bao1 +bao2 +bao3 +bao4 +bei +bei1 +bei2 +bei3 +bei4 +ben1 +ben2 +ben3 +ben4 +beng +beng1 +beng2 +beng3 +beng4 +bi1 +bi2 +bi3 +bi4 +bian1 +bian2 +bian3 +bian4 +biao1 +biao2 +biao3 +bie1 +bie2 +bie3 +bie4 +bin1 +bin4 +bing1 +bing2 +bing3 +bing4 +bo +bo1 +bo2 +bo3 +bo4 +bu2 +bu3 +bu4 +c +ca1 +cai1 +cai2 +cai3 +cai4 +can1 +can2 +can3 +can4 +cang1 +cang2 +cao1 +cao2 +cao3 +ce4 +cen1 +cen2 +ceng1 +ceng2 +ceng4 +cha1 +cha2 +cha3 +cha4 +chai1 +chai2 +chan1 +chan2 +chan3 +chan4 +chang1 +chang2 +chang3 +chang4 +chao1 +chao2 +chao3 +che1 +che2 +che3 +che4 +chen1 +chen2 +chen3 +chen4 +cheng1 +cheng2 +cheng3 +cheng4 +chi1 +chi2 +chi3 +chi4 +chong1 +chong2 +chong3 +chong4 +chou1 +chou2 +chou3 +chou4 +chu1 +chu2 +chu3 +chu4 +chua1 +chuai1 +chuai2 +chuai3 +chuai4 +chuan1 +chuan2 +chuan3 +chuan4 +chuang1 +chuang2 +chuang3 +chuang4 +chui1 +chui2 +chun1 +chun2 +chun3 +chuo1 +chuo4 +ci1 +ci2 +ci3 +ci4 +cong1 +cong2 +cou4 +cu1 +cu4 +cuan1 +cuan2 +cuan4 +cui1 +cui3 +cui4 +cun1 +cun2 +cun4 +cuo1 +cuo2 +cuo4 +d +da +da1 +da2 +da3 +da4 +dai1 +dai2 +dai3 +dai4 +dan1 +dan2 +dan3 +dan4 +dang1 +dang2 +dang3 +dang4 +dao1 +dao2 +dao3 +dao4 +de +de1 +de2 +dei3 +den4 +deng1 +deng2 +deng3 +deng4 +di1 +di2 +di3 +di4 +dia3 +dian1 +dian2 +dian3 +dian4 +diao1 +diao3 +diao4 +die1 +die2 +die4 +ding1 +ding2 +ding3 +ding4 +diu1 +dong1 +dong3 +dong4 +dou1 +dou2 +dou3 +dou4 +du1 +du2 +du3 +du4 +duan1 +duan2 +duan3 +duan4 +dui1 +dui4 +dun1 +dun3 +dun4 +duo1 +duo2 +duo3 +duo4 +e +e1 +e2 +e3 +e4 +ei2 +en1 +en4 +er +er2 +er3 +er4 +f +fa1 +fa2 +fa3 +fa4 +fan1 +fan2 +fan3 +fan4 +fang1 +fang2 
+fang3 +fang4 +fei1 +fei2 +fei3 +fei4 +fen1 +fen2 +fen3 +fen4 +feng1 +feng2 +feng3 +feng4 +fo2 +fou2 +fou3 +fu1 +fu2 +fu3 +fu4 +g +ga1 +ga2 +ga3 +ga4 +gai1 +gai2 +gai3 +gai4 +gan1 +gan2 +gan3 +gan4 +gang1 +gang2 +gang3 +gang4 +gao1 +gao2 +gao3 +gao4 +ge1 +ge2 +ge3 +ge4 +gei2 +gei3 +gen1 +gen2 +gen3 +gen4 +geng1 +geng3 +geng4 +gong1 +gong3 +gong4 +gou1 +gou2 +gou3 +gou4 +gu +gu1 +gu2 +gu3 +gu4 +gua1 +gua2 +gua3 +gua4 +guai1 +guai2 +guai3 +guai4 +guan1 +guan2 +guan3 +guan4 +guang1 +guang2 +guang3 +guang4 +gui1 +gui2 +gui3 +gui4 +gun3 +gun4 +guo1 +guo2 +guo3 +guo4 +h +ha1 +ha2 +ha3 +hai1 +hai2 +hai3 +hai4 +han1 +han2 +han3 +han4 +hang1 +hang2 +hang4 +hao1 +hao2 +hao3 +hao4 +he1 +he2 +he4 +hei1 +hen2 +hen3 +hen4 +heng1 +heng2 +heng4 +hong1 +hong2 +hong3 +hong4 +hou1 +hou2 +hou3 +hou4 +hu1 +hu2 +hu3 +hu4 +hua1 +hua2 +hua4 +huai2 +huai4 +huan1 +huan2 +huan3 +huan4 +huang1 +huang2 +huang3 +huang4 +hui1 +hui2 +hui3 +hui4 +hun1 +hun2 +hun4 +huo +huo1 +huo2 +huo3 +huo4 +i +j +ji1 +ji2 +ji3 +ji4 +jia +jia1 +jia2 +jia3 +jia4 +jian1 +jian2 +jian3 +jian4 +jiang1 +jiang2 +jiang3 +jiang4 +jiao1 +jiao2 +jiao3 +jiao4 +jie1 +jie2 +jie3 +jie4 +jin1 +jin2 +jin3 +jin4 +jing1 +jing2 +jing3 +jing4 +jiong3 +jiu1 +jiu2 +jiu3 +jiu4 +ju1 +ju2 +ju3 +ju4 +juan1 +juan2 +juan3 +juan4 +jue1 +jue2 +jue4 +jun1 +jun4 +k +ka1 +ka2 +ka3 +kai1 +kai2 +kai3 +kai4 +kan1 +kan2 +kan3 +kan4 +kang1 +kang2 +kang4 +kao1 +kao2 +kao3 +kao4 +ke1 +ke2 +ke3 +ke4 +ken3 +keng1 +kong1 +kong3 +kong4 +kou1 +kou2 +kou3 +kou4 +ku1 +ku2 +ku3 +ku4 +kua1 +kua3 +kua4 +kuai3 +kuai4 +kuan1 +kuan2 +kuan3 +kuang1 +kuang2 +kuang4 +kui1 +kui2 +kui3 +kui4 +kun1 +kun3 +kun4 +kuo4 +l +la +la1 +la2 +la3 +la4 +lai2 +lai4 +lan2 +lan3 +lan4 +lang1 +lang2 +lang3 +lang4 +lao1 +lao2 +lao3 +lao4 +le +le1 +le4 +lei +lei1 +lei2 +lei3 +lei4 +leng1 +leng2 +leng3 +leng4 +li +li1 +li2 +li3 +li4 +lia3 +lian2 +lian3 +lian4 +liang2 +liang3 +liang4 +liao1 +liao2 +liao3 +liao4 +lie1 +lie2 +lie3 +lie4 +lin1 +lin2 +lin3 +lin4 +ling2 +ling3 +ling4 +liu1 +liu2 +liu3 +liu4 +long1 +long2 +long3 +long4 +lou1 +lou2 +lou3 +lou4 +lu1 +lu2 +lu3 +lu4 +luan2 +luan3 +luan4 +lun1 +lun2 +lun4 +luo1 +luo2 +luo3 +luo4 +lv2 +lv3 +lv4 +lve3 +lve4 +m +ma +ma1 +ma2 +ma3 +ma4 +mai2 +mai3 +mai4 +man1 +man2 +man3 +man4 +mang2 +mang3 +mao1 +mao2 +mao3 +mao4 +me +mei2 +mei3 +mei4 +men +men1 +men2 +men4 +meng +meng1 +meng2 +meng3 +meng4 +mi1 +mi2 +mi3 +mi4 +mian2 +mian3 +mian4 +miao1 +miao2 +miao3 +miao4 +mie1 +mie4 +min2 +min3 +ming2 +ming3 +ming4 +miu4 +mo1 +mo2 +mo3 +mo4 +mou1 +mou2 +mou3 +mu2 +mu3 +mu4 +n +n2 +na1 +na2 +na3 +na4 +nai2 +nai3 +nai4 +nan1 +nan2 +nan3 +nan4 +nang1 +nang2 +nang3 +nao1 +nao2 +nao3 +nao4 +ne +ne2 +ne4 +nei3 +nei4 +nen4 +neng2 +ni1 +ni2 +ni3 +ni4 +nian1 +nian2 +nian3 +nian4 +niang2 +niang4 +niao2 +niao3 +niao4 +nie1 +nie4 +nin2 +ning2 +ning3 +ning4 +niu1 +niu2 +niu3 +niu4 +nong2 +nong4 +nou4 +nu2 +nu3 +nu4 +nuan3 +nuo2 +nuo4 +nv2 +nv3 +nve4 +o +o1 +o2 +ou1 +ou2 +ou3 +ou4 +p +pa1 +pa2 +pa4 +pai1 +pai2 +pai3 +pai4 +pan1 +pan2 +pan4 +pang1 +pang2 +pang4 +pao1 +pao2 +pao3 +pao4 +pei1 +pei2 +pei4 +pen1 +pen2 +pen4 +peng1 +peng2 +peng3 +peng4 +pi1 +pi2 +pi3 +pi4 +pian1 +pian2 +pian4 +piao1 +piao2 +piao3 +piao4 +pie1 +pie2 +pie3 +pin1 +pin2 +pin3 +pin4 +ping1 +ping2 +po1 +po2 +po3 +po4 +pou1 +pu1 +pu2 +pu3 +pu4 +q +qi1 +qi2 +qi3 +qi4 +qia1 +qia3 +qia4 +qian1 +qian2 +qian3 +qian4 +qiang1 +qiang2 +qiang3 +qiang4 +qiao1 +qiao2 +qiao3 +qiao4 +qie1 +qie2 +qie3 +qie4 +qin1 +qin2 +qin3 +qin4 +qing1 +qing2 +qing3 +qing4 +qiong1 +qiong2 +qiu1 +qiu2 +qiu3 +qu1 +qu2 +qu3 +qu4 +quan1 +quan2 
+quan3 +quan4 +que1 +que2 +que4 +qun2 +r +ran2 +ran3 +rang1 +rang2 +rang3 +rang4 +rao2 +rao3 +rao4 +re2 +re3 +re4 +ren2 +ren3 +ren4 +reng1 +reng2 +ri4 +rong1 +rong2 +rong3 +rou2 +rou4 +ru2 +ru3 +ru4 +ruan2 +ruan3 +rui3 +rui4 +run4 +ruo4 +s +sa1 +sa2 +sa3 +sa4 +sai1 +sai4 +san1 +san2 +san3 +san4 +sang1 +sang3 +sang4 +sao1 +sao2 +sao3 +sao4 +se4 +sen1 +seng1 +sha1 +sha2 +sha3 +sha4 +shai1 +shai2 +shai3 +shai4 +shan1 +shan3 +shan4 +shang +shang1 +shang3 +shang4 +shao1 +shao2 +shao3 +shao4 +she1 +she2 +she3 +she4 +shei2 +shen1 +shen2 +shen3 +shen4 +sheng1 +sheng2 +sheng3 +sheng4 +shi +shi1 +shi2 +shi3 +shi4 +shou1 +shou2 +shou3 +shou4 +shu1 +shu2 +shu3 +shu4 +shua1 +shua2 +shua3 +shua4 +shuai1 +shuai3 +shuai4 +shuan1 +shuan4 +shuang1 +shuang3 +shui2 +shui3 +shui4 +shun3 +shun4 +shuo1 +shuo4 +si1 +si2 +si3 +si4 +song1 +song3 +song4 +sou1 +sou3 +sou4 +su1 +su2 +su4 +suan1 +suan4 +sui1 +sui2 +sui3 +sui4 +sun1 +sun3 +suo +suo1 +suo2 +suo3 +t +ta1 +ta2 +ta3 +ta4 +tai1 +tai2 +tai4 +tan1 +tan2 +tan3 +tan4 +tang1 +tang2 +tang3 +tang4 +tao1 +tao2 +tao3 +tao4 +te4 +teng2 +ti1 +ti2 +ti3 +ti4 +tian1 +tian2 +tian3 +tiao1 +tiao2 +tiao3 +tiao4 +tie1 +tie2 +tie3 +tie4 +ting1 +ting2 +ting3 +tong1 +tong2 +tong3 +tong4 +tou +tou1 +tou2 +tou4 +tu1 +tu2 +tu3 +tu4 +tuan1 +tuan2 +tui1 +tui2 +tui3 +tui4 +tun1 +tun2 +tun4 +tuo1 +tuo2 +tuo3 +tuo4 +u +v +w +wa +wa1 +wa2 +wa3 +wa4 +wai1 +wai3 +wai4 +wan1 +wan2 +wan3 +wan4 +wang1 +wang2 +wang3 +wang4 +wei1 +wei2 +wei3 +wei4 +wen1 +wen2 +wen3 +wen4 +weng1 +weng4 +wo1 +wo2 +wo3 +wo4 +wu1 +wu2 +wu3 +wu4 +x +xi1 +xi2 +xi3 +xi4 +xia1 +xia2 +xia4 +xian1 +xian2 +xian3 +xian4 +xiang1 +xiang2 +xiang3 +xiang4 +xiao1 +xiao2 +xiao3 +xiao4 +xie1 +xie2 +xie3 +xie4 +xin1 +xin2 +xin4 +xing1 +xing2 +xing3 +xing4 +xiong1 +xiong2 +xiu1 +xiu3 +xiu4 +xu +xu1 +xu2 +xu3 +xu4 +xuan1 +xuan2 +xuan3 +xuan4 +xue1 +xue2 +xue3 +xue4 +xun1 +xun2 +xun4 +y +ya +ya1 +ya2 +ya3 +ya4 +yan1 +yan2 +yan3 +yan4 +yang1 +yang2 +yang3 +yang4 +yao1 +yao2 +yao3 +yao4 +ye1 +ye2 +ye3 +ye4 +yi +yi1 +yi2 +yi3 +yi4 +yin1 +yin2 +yin3 +yin4 +ying1 +ying2 +ying3 +ying4 +yo1 +yong1 +yong2 +yong3 +yong4 +you1 +you2 +you3 +you4 +yu1 +yu2 +yu3 +yu4 +yuan1 +yuan2 +yuan3 +yuan4 +yue1 +yue4 +yun1 +yun2 +yun3 +yun4 +z +za1 +za2 +za3 +zai1 +zai3 +zai4 +zan1 +zan2 +zan3 +zan4 +zang1 +zang4 +zao1 +zao2 +zao3 +zao4 +ze2 +ze4 +zei2 +zen3 +zeng1 +zeng4 +zha1 +zha2 +zha3 +zha4 +zhai1 +zhai2 +zhai3 +zhai4 +zhan1 +zhan2 +zhan3 +zhan4 +zhang1 +zhang2 +zhang3 +zhang4 +zhao1 +zhao2 +zhao3 +zhao4 +zhe +zhe1 +zhe2 +zhe3 +zhe4 +zhen1 +zhen2 +zhen3 +zhen4 +zheng1 +zheng2 +zheng3 +zheng4 +zhi1 +zhi2 +zhi3 +zhi4 +zhong1 +zhong2 +zhong3 +zhong4 +zhou1 +zhou2 +zhou3 +zhou4 +zhu1 +zhu2 +zhu3 +zhu4 +zhua1 +zhua2 +zhua3 +zhuai1 +zhuai3 +zhuai4 +zhuan1 +zhuan2 +zhuan3 +zhuan4 +zhuang1 +zhuang4 +zhui1 +zhui4 +zhun1 +zhun2 +zhun3 +zhuo1 +zhuo2 +zi +zi1 +zi2 +zi3 +zi4 +zong1 +zong2 +zong3 +zong4 +zou1 +zou2 +zou3 +zou4 +zu1 +zu2 +zu3 +zuan1 +zuan3 +zuan4 +zui2 +zui3 +zui4 +zun1 +zuo +zuo1 +zuo2 +zuo3 +zuo4 +{ +~ +¡ +¢ +£ +¥ +§ +¨ +© +« +® +¯ +° +± +² +³ +´ +µ +· +¹ +º +» +¼ +½ +¾ +¿ +À +Á + +à +Ä +Å +Æ +Ç +È +É +Ê +Í +Î +Ñ +Ó +Ö +× +Ø +Ú +Ü +Ý +Þ +ß +à +á +â +ã +ä +å +æ +ç +è +é +ê +ë +ì +í +î +ï +ð +ñ +ò +ó +ô +õ +ö +ø +ù +ú +û +ü +ý +Ā +ā +ă +ą +ć +Č +č +Đ +đ +ē +ė +ę +ě +ĝ +ğ +ħ +ī +į +İ +ı +Ł +ł +ń +ņ +ň +ŋ +Ō +ō +ő +œ +ř +Ś +ś +Ş +ş +Š +š +Ť +ť +ũ +ū +ź +Ż +ż +Ž +ž +ơ +ư +ǎ +ǐ +ǒ +ǔ +ǚ +ș +ț +ɑ +ɔ +ɕ +ə +ɛ +ɜ +ɡ +ɣ +ɪ +ɫ +ɴ +ɹ +ɾ +ʃ +ʊ +ʌ +ʒ +ʔ +ʰ +ʷ +ʻ +ʾ +ʿ +ˈ +ː +˙ +˜ +ˢ +́ +̅ +Α +Β +Δ +Ε +Θ +Κ +Λ +Μ +Ξ +Π +Σ +Τ +Φ +Χ +Ψ +Ω +ά +έ +ή +ί +α +β +γ +δ +ε 
+ζ +η +θ +ι +κ +λ +μ +ν +ξ +ο +π +ρ +ς +σ +τ +υ +φ +χ +ψ +ω +ϊ +ό +ύ +ώ +ϕ +ϵ +Ё +А +Б +В +Г +Д +Е +Ж +З +И +Й +К +Л +М +Н +О +П +Р +С +Т +У +Ф +Х +Ц +Ч +Ш +Щ +Ы +Ь +Э +Ю +Я +а +б +в +г +д +е +ж +з +и +й +к +л +м +н +о +п +р +с +т +у +ф +х +ц +ч +ш +щ +ъ +ы +ь +э +ю +я +ё +і +ְ +ִ +ֵ +ֶ +ַ +ָ +ֹ +ּ +־ +ׁ +א +ב +ג +ד +ה +ו +ז +ח +ט +י +כ +ל +ם +מ +ן +נ +ס +ע +פ +ק +ר +ש +ת +أ +ب +ة +ت +ج +ح +د +ر +ز +س +ص +ط +ع +ق +ك +ل +م +ن +ه +و +ي +َ +ُ +ِ +ْ +ก +ข +ง +จ +ต +ท +น +ป +ย +ร +ว +ส +ห +อ +ฮ +ั +า +ี +ึ +โ +ใ +ไ +่ +้ +์ +ḍ +Ḥ +ḥ +ṁ +ṃ +ṅ +ṇ +Ṛ +ṛ +Ṣ +ṣ +Ṭ +ṭ +ạ +ả +Ấ +ấ +ầ +ậ +ắ +ằ +ẻ +ẽ +ế +ề +ể +ễ +ệ +ị +ọ +ỏ +ố +ồ +ộ +ớ +ờ +ở +ụ +ủ +ứ +ữ +ἀ +ἁ +Ἀ +ἐ +ἔ +ἰ +ἱ +ὀ +ὁ +ὐ +ὲ +ὸ +ᾶ +᾽ +ῆ +ῇ +ῶ +‎ +‑ +‒ +– +— +― +‖ +† +‡ +• +… +‧ +‬ +′ +″ +⁄ +⁡ +⁰ +⁴ +⁵ +⁶ +⁷ +⁸ +⁹ +₁ +₂ +₃ +€ +₱ +₹ +₽ +℃ +ℏ +ℓ +№ +ℝ +™ +⅓ +⅔ +⅛ +→ +∂ +∈ +∑ +− +∗ +√ +∞ +∫ +≈ +≠ +≡ +≤ +≥ +⋅ +⋯ +█ +♪ +⟨ +⟩ +、 +。 +《 +》 +「 +」 +【 +】 +あ +う +え +お +か +が +き +ぎ +く +ぐ +け +げ +こ +ご +さ +し +じ +す +ず +せ +ぜ +そ +ぞ +た +だ +ち +っ +つ +で +と +ど +な +に +ね +の +は +ば +ひ +ぶ +へ +べ +ま +み +む +め +も +ゃ +や +ゆ +ょ +よ +ら +り +る +れ +ろ +わ +を +ん +ァ +ア +ィ +イ +ウ +ェ +エ +オ +カ +ガ +キ +ク +ケ +ゲ +コ +ゴ +サ +ザ +シ +ジ +ス +ズ +セ +ゾ +タ +ダ +チ +ッ +ツ +テ +デ +ト +ド +ナ +ニ +ネ +ノ +バ +パ +ビ +ピ +フ +プ +ヘ +ベ +ペ +ホ +ボ +ポ +マ +ミ +ム +メ +モ +ャ +ヤ +ュ +ユ +ョ +ヨ +ラ +リ +ル +レ +ロ +ワ +ン +・ +ー +ㄋ +ㄍ +ㄎ +ㄏ +ㄓ +ㄕ +ㄚ +ㄜ +ㄟ +ㄤ +ㄥ +ㄧ +ㄱ +ㄴ +ㄷ +ㄹ +ㅁ +ㅂ +ㅅ +ㅈ +ㅍ +ㅎ +ㅏ +ㅓ +ㅗ +ㅜ +ㅡ +ㅣ +㗎 +가 +각 +간 +갈 +감 +갑 +갓 +갔 +강 +같 +개 +거 +건 +걸 +겁 +것 +겉 +게 +겠 +겨 +결 +겼 +경 +계 +고 +곤 +골 +곱 +공 +과 +관 +광 +교 +구 +국 +굴 +귀 +귄 +그 +근 +글 +금 +기 +긴 +길 +까 +깍 +깔 +깜 +깨 +께 +꼬 +꼭 +꽃 +꾸 +꿔 +끔 +끗 +끝 +끼 +나 +난 +날 +남 +납 +내 +냐 +냥 +너 +넘 +넣 +네 +녁 +년 +녕 +노 +녹 +놀 +누 +눈 +느 +는 +늘 +니 +님 +닙 +다 +닥 +단 +달 +닭 +당 +대 +더 +덕 +던 +덥 +데 +도 +독 +동 +돼 +됐 +되 +된 +될 +두 +둑 +둥 +드 +들 +등 +디 +따 +딱 +딸 +땅 +때 +떤 +떨 +떻 +또 +똑 +뚱 +뛰 +뜻 +띠 +라 +락 +란 +람 +랍 +랑 +래 +랜 +러 +런 +럼 +렇 +레 +려 +력 +렵 +렸 +로 +록 +롬 +루 +르 +른 +를 +름 +릉 +리 +릴 +림 +마 +막 +만 +많 +말 +맑 +맙 +맛 +매 +머 +먹 +멍 +메 +면 +명 +몇 +모 +목 +몸 +못 +무 +문 +물 +뭐 +뭘 +미 +민 +밌 +밑 +바 +박 +밖 +반 +받 +발 +밤 +밥 +방 +배 +백 +밸 +뱀 +버 +번 +벌 +벚 +베 +벼 +벽 +별 +병 +보 +복 +본 +볼 +봐 +봤 +부 +분 +불 +비 +빔 +빛 +빠 +빨 +뼈 +뽀 +뿅 +쁘 +사 +산 +살 +삼 +샀 +상 +새 +색 +생 +서 +선 +설 +섭 +섰 +성 +세 +셔 +션 +셨 +소 +속 +손 +송 +수 +숙 +순 +술 +숫 +숭 +숲 +쉬 +쉽 +스 +슨 +습 +슷 +시 +식 +신 +실 +싫 +심 +십 +싶 +싸 +써 +쓰 +쓴 +씌 +씨 +씩 +씬 +아 +악 +안 +않 +알 +야 +약 +얀 +양 +얘 +어 +언 +얼 +엄 +업 +없 +었 +엉 +에 +여 +역 +연 +염 +엽 +영 +옆 +예 +옛 +오 +온 +올 +옷 +옹 +와 +왔 +왜 +요 +욕 +용 +우 +운 +울 +웃 +워 +원 +월 +웠 +위 +윙 +유 +육 +윤 +으 +은 +을 +음 +응 +의 +이 +익 +인 +일 +읽 +임 +입 +있 +자 +작 +잔 +잖 +잘 +잡 +잤 +장 +재 +저 +전 +점 +정 +제 +져 +졌 +조 +족 +좀 +종 +좋 +죠 +주 +준 +줄 +중 +줘 +즈 +즐 +즘 +지 +진 +집 +짜 +짝 +쩌 +쪼 +쪽 +쫌 +쭈 +쯔 +찌 +찍 +차 +착 +찾 +책 +처 +천 +철 +체 +쳐 +쳤 +초 +촌 +추 +출 +춤 +춥 +춰 +치 +친 +칠 +침 +칩 +칼 +커 +켓 +코 +콩 +쿠 +퀴 +크 +큰 +큽 +키 +킨 +타 +태 +터 +턴 +털 +테 +토 +통 +투 +트 +특 +튼 +틀 +티 +팀 +파 +팔 +패 +페 +펜 +펭 +평 +포 +폭 +표 +품 +풍 +프 +플 +피 +필 +하 +학 +한 +할 +함 +합 +항 +해 +햇 +했 +행 +허 +험 +형 +혜 +호 +혼 +홀 +화 +회 +획 +후 +휴 +흐 +흔 +희 +히 +힘 +ﷺ +ﷻ +! +, +? +� +𠮶 diff --git a/egs/wenetspeech4tts/TTS/local/attach_speech_tokens.py b/egs/wenetspeech4tts/TTS/local/attach_speech_tokens.py new file mode 100644 index 000000000..9904901f0 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/attach_speech_tokens.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# Copyright 2025 author: Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gzip +import json +import logging + +import s3tokenizer +from lhotse import CutSet, load_manifest_lazy +from tqdm import tqdm + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--manifest-dir", + type=str, + default="data/fbank", + help="Directory to store the manifest files", + ) + + parser.add_argument( + "--jsonl-prefix", + type=str, + default="wenetspeech4tts_cuts_valid", + help="The training subset for wenetspeech.", + ) + + parser.add_argument( + "--tokens-path", + type=str, + default="./s3_tokens_valid/wenetspeech4tts_valid.json", + help="json file containing the speech tokens", + ) + + return parser + + +def get_speech_tokens(tokens_path): + id2tokens = {} + with open(tokens_path, "r") as fin: + for line in fin: + line = json.loads(line) + id2tokens[line["key"]] = " ".join(map(str, line["code"])) + return id2tokens + + +def attach_manifest(manifest, fixed_manifest_path, id2tokens): + with CutSet.open_writer(fixed_manifest_path) as manifest_writer: + fixed_item = 0 + for i, cut in enumerate(tqdm(manifest)): + cut_id = cut.supervisions[0].id + if cut_id in id2tokens: + code = id2tokens[cut_id] + cut.supervisions[0].custom = { + **cut.supervisions[0].custom, + **{"speech_tokens": code}, + } + else: + print(f"cut_id {cut_id} not in id2tokens") + fixed_item += 1 + manifest_writer.write(cut) + logging.info(f"Fixed {fixed_item} items in the manifest") + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + manifest_path = args.manifest_dir + "/" + f"{args.jsonl_prefix}.jsonl.gz" + attached_manifest_path = ( + args.manifest_dir + "/" + f"{args.jsonl_prefix}_attached_cosyvoice_v2.jsonl.gz" + ) + logging.info(f"Loading manifest from {manifest_path}") + cuts_manifest = load_manifest_lazy(manifest_path) + logging.info(f"Loading manifest from {manifest_path} done") + id2tokens = get_speech_tokens(args.tokens_path) + logging.info(f"Loaded id2tokens with {len(id2tokens)} entries") + + attach_manifest(cuts_manifest, attached_manifest_path, id2tokens) + logging.info( + f"Manifest with speech tokens attached is saved to {attached_manifest_path}" + ) + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/local/audio.py b/egs/wenetspeech4tts/TTS/local/audio.py new file mode 100644 index 000000000..b643e3de0 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/audio.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
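+
+# Note: this module provides a mel-spectrogram front end whose defaults
+# (1024-point FFT, 100 mel bins, 24 kHz sampling rate, hop size 256) match
+# the BigVGAN feature configuration used elsewhere in this recipe.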
+ +import math +import os +import pathlib +import random +from typing import List, Optional, Tuple + +import librosa +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from tqdm import tqdm + +# from env import AttrDict + +MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases) + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + return dynamic_range_compression_torch(magnitudes) + + +def spectral_de_normalize_torch(magnitudes): + return dynamic_range_decompression_torch(magnitudes) + + +mel_basis_cache = {} +hann_window_cache = {} + + +def mel_spectrogram( + y: torch.Tensor, + n_fft: int = 1024, + num_mels: int = 100, + sampling_rate: int = 24_000, + hop_size: int = 256, + win_size: int = 1024, + fmin: int = 0, + fmax: int = None, + center: bool = False, +) -> torch.Tensor: + """ + Calculate the mel spectrogram of an input signal. + This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). + + Args: + y (torch.Tensor): Input signal. + n_fft (int): FFT size. + num_mels (int): Number of mel bins. + sampling_rate (int): Sampling rate of the input signal. + hop_size (int): Hop size for STFT. + win_size (int): Window size for STFT. + fmin (int): Minimum frequency for mel filterbank. + fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn + center (bool): Whether to pad the input to center the frames. Default is False. + + Returns: + torch.Tensor: Mel spectrogram. + """ + if torch.min(y) < -1.0: + print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") + if torch.max(y) > 1.0: + print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") + + device = y.device + key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) + hann_window_cache[key] = torch.hann_window(win_size).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_size) // 2 + y = torch.nn.functional.pad( + y.unsqueeze(1), (padding, padding), mode="reflect" + ).squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = spectral_normalize_torch(mel_spec) + + return mel_spec diff --git a/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py b/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py new file mode 100755 index 000000000..5292c75ad --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LJSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from fbank import MatchaFbank, MatchaFbankConfig +from lhotse import CutSet, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--num-jobs", + type=int, + default=1, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--src-dir", + type=Path, + default=Path("data/manifests"), + help="Path to the manifest files", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("data/fbank"), + help="Path to the tokenized files", + ) + + parser.add_argument( + "--dataset-parts", + type=str, + default="Basic", + help="Space separated dataset parts", + ) + + parser.add_argument( + "--prefix", + type=str, + default="wenetspeech4tts", + help="prefix of the manifest file", + ) + + parser.add_argument( + "--suffix", + type=str, + default="jsonl.gz", + help="suffix of the manifest file", + ) + + parser.add_argument( + "--split", + type=int, + default=100, + help="Split the cut_set into multiple parts", + ) + + parser.add_argument( + "--resample-to-24kHz", + default=True, + help="Resample the audio to 24kHz", + ) + + parser.add_argument( + "--extractor", + type=str, + choices=["bigvgan", "hifigan"], + default="bigvgan", + help="The type of extractor to use", + ) + return parser + + +def compute_fbank(args): + src_dir = Path(args.src_dir) + output_dir = Path(args.output_dir) + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + num_jobs = min(args.num_jobs, os.cpu_count()) + dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip().split(" ") + + logging.info(f"num_jobs: {num_jobs}") + logging.info(f"src_dir: {src_dir}") + logging.info(f"output_dir: {output_dir}") + logging.info(f"dataset_parts: {dataset_parts}") + if args.extractor == "bigvgan": + config = MatchaFbankConfig( + n_fft=1024, + n_mels=100, + sampling_rate=24_000, + hop_length=256, + win_length=1024, + f_min=0, + f_max=None, + ) + elif args.extractor == "hifigan": + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=22050, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + else: + raise NotImplementedError(f"Extractor {args.extractor} is not implemented") + + extractor = MatchaFbank(config) + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=args.src_dir, + prefix=args.prefix, + suffix=args.suffix, + 
types=["recordings", "supervisions", "cuts"], + ) + + with get_executor() as ex: + for partition, m in manifests.items(): + logging.info( + f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" + ) + try: + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + except Exception: + cut_set = m["cuts"] + + if args.split > 1: + cut_sets = cut_set.split(args.split) + else: + cut_sets = [cut_set] + + for idx, part in enumerate(cut_sets): + if args.split > 1: + storage_path = f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}_{idx}" + else: + storage_path = ( + f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}" + ) + + if args.resample_to_24kHz: + part = part.resample(24000) + + with torch.no_grad(): + part = part.compute_and_store_features( + extractor=extractor, + storage_path=storage_path, + num_jobs=num_jobs if ex is None else 64, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + if args.split > 1: + cuts_filename = ( + f"{args.prefix}_cuts_{partition}.{idx}.{args.suffix}" + ) + else: + cuts_filename = f"{args.prefix}_cuts_{partition}.{args.suffix}" + + part.to_file(f"{args.output_dir}/{cuts_filename}") + logging.info(f"Saved {cuts_filename}") + + +if __name__ == "__main__": + # Torch's multithreaded behavior needs to be disabled or + # it wastes a lot of CPU and slow things down. + # Do this outside of main() in case it needs to take effect + # even when we are not invoking the main (e.g. when spawning subprocesses). + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_parser().parse_args() + compute_fbank(args) diff --git a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py new file mode 100755 index 000000000..7de2c6202 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py @@ -0,0 +1,621 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Feiteng Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Phonemize Text and EnCodec Audio. 
+ +Usage example: + python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ + --text-extractor ${text_extractor} \ + --audio-extractor ${audio_extractor} \ + --batch-duration 2500 --prefix "wenetspeech4tts" \ + --src-dir "data/manifests" --split 100 \ + --output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100" + +""" +import argparse +import logging +import os +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.multiprocessing +from encodec import EncodecModel +from encodec.utils import convert_audio +from lhotse import CutSet, NumpyHdf5Writer +from lhotse.features import FeatureExtractor +from lhotse.recipes.utils import read_manifests_if_cached +from lhotse.utils import Seconds, compute_num_frames +from phonemizer.backend import EspeakBackend +from phonemizer.backend.espeak.language_switch import LanguageSwitch +from phonemizer.backend.espeak.words_mismatch import WordMismatch +from phonemizer.punctuation import Punctuation +from phonemizer.separator import Separator +from tqdm.auto import tqdm + +from icefall.utils import get_executor + +try: + from pypinyin import Style, pinyin + from pypinyin.style._utils import get_finals, get_initials +except Exception: + pass + + +import re +from typing import Pattern + +import numpy as np +from k2 import SymbolTable + +# from valle.data import ( +# AudioTokenConfig, +# AudioTokenExtractor, +# TextTokenizer, +# tokenize_text, +# ) +# from valle.data.fbank import get_fbank_extractor +# from valle.utils import SymbolTable + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--src-dir", + type=Path, + default=Path("data/manifests"), + help="Path to the manifest files", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("data/tokenized"), + help="Path to the tokenized files", + ) + parser.add_argument( + "--text-extractor", + type=str, + default="espeak", + help="espeak or pypinyin or pypinyin_initials_finals", + ) + parser.add_argument( + "--audio-extractor", + type=str, + default="Encodec", + help="Encodec or Fbank", + ) + parser.add_argument( + "--dataset-parts", + type=str, + default="dev-clean test-clean", + help="Space separated dataset parts", + ) + parser.add_argument( + "--prefix", + type=str, + default="libritts", + help="prefix of the manifest file", + ) + parser.add_argument( + "--suffix", + type=str, + default="jsonl.gz", + help="suffix of the manifest file", + ) + parser.add_argument( + "--batch-duration", + type=float, + default=400.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + parser.add_argument( + "--split", + type=int, + default=1, + help="Split the cut_set into multiple parts", + ) + + return parser.parse_args() + + +class PypinyinBackend: + """PypinyinBackend for Chinese. Most codes is referenced from espnet. + There are two types pinyin or initials_finals, one is + just like "ni1 hao3", the other is like "n i1 h ao3". 
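+    Tones follow pypinyin Style.TONE3 with neutral_tone_with_five=True, so every
+    syllable ends in a digit 1-5 (5 marks the neutral tone); punctuation marks are
+    kept as individual characters.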
+ """ + + def __init__( + self, + backend="initials_finals", + punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), + ) -> None: + self.backend = backend + self.punctuation_marks = punctuation_marks + + def phonemize( + self, text: List[str], separator: Separator, strip=True, njobs=1 + ) -> List[str]: + assert isinstance(text, List) + phonemized = [] + for _text in text: + _text = re.sub(" +", " ", _text.strip()) + _text = _text.replace(" ", separator.word) + phones = [] + if self.backend == "pypinyin": + for n, py in enumerate( + pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) + ): + if all([c in self.punctuation_marks for c in py[0]]): + if len(phones): + assert phones[-1] == separator.syllable + phones.pop(-1) + + phones.extend(list(py[0])) + else: + phones.extend([py[0], separator.syllable]) + elif self.backend == "pypinyin_initials_finals": + for n, py in enumerate( + pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) + ): + if all([c in self.punctuation_marks for c in py[0]]): + if len(phones): + assert phones[-1] == separator.syllable + phones.pop(-1) + phones.extend(list(py[0])) + else: + if py[0][-1].isalnum(): + initial = get_initials(py[0], strict=False) + if py[0][-1].isdigit(): + final = get_finals(py[0][:-1], strict=False) + py[0][-1] + else: + final = get_finals(py[0], strict=False) + phones.extend( + [ + initial, + separator.phone, + final, + separator.syllable, + ] + ) + else: + assert ValueError + else: + raise NotImplementedError + phonemized.append( + "".join(phones).rstrip(f"{separator.word}{separator.syllable}") + ) + return phonemized + + +class TextTokenizer: + """Phonemize Text.""" + + def __init__( + self, + language="en-us", + backend="espeak", + separator=Separator(word="_", syllable="-", phone="|"), + preserve_punctuation=True, + punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), + with_stress: bool = False, + tie: Union[bool, str] = False, + language_switch: LanguageSwitch = "keep-flags", + words_mismatch: WordMismatch = "ignore", + ) -> None: + if backend == "espeak": + phonemizer = EspeakBackend( + language, + punctuation_marks=punctuation_marks, + preserve_punctuation=preserve_punctuation, + with_stress=with_stress, + tie=tie, + language_switch=language_switch, + words_mismatch=words_mismatch, + ) + elif backend in ["pypinyin", "pypinyin_initials_finals"]: + phonemizer = PypinyinBackend( + backend=backend, + punctuation_marks=punctuation_marks + separator.word, + ) + else: + raise NotImplementedError(f"{backend}") + + self.backend = phonemizer + self.separator = separator + + def to_list(self, phonemized: str) -> List[str]: + fields = [] + for word in phonemized.split(self.separator.word): + # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. 
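+            # The phonemized string is split on the word separator; the phone
+            # separator ("|") is dropped and every remaining piece (including the
+            # syllable separator) becomes one token, e.g.
+            # "n|i3-h|ao3" -> ["n", "i3", "-", "h", "ao3"].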
+ pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) + fields.extend( + [p for p in pp if p != self.separator.phone] + [self.separator.word] + ) + assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( + self.separator.phone + ) + return fields[:-1] + + def __call__(self, text, strip=True) -> List[List[str]]: + if isinstance(text, str): + text = [text] + + phonemized = self.backend.phonemize( + text, separator=self.separator, strip=strip, njobs=1 + ) + return [self.to_list(p) for p in phonemized] + + +def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: + phonemes = tokenizer([text.strip()]) + return phonemes[0] # k2symbols + + +def remove_encodec_weight_norm(model): + from encodec.modules import SConv1d + from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock + from torch.nn.utils import remove_weight_norm + + encoder = model.encoder.model + for key in encoder._modules: + if isinstance(encoder._modules[key], SEANetResnetBlock): + remove_weight_norm(encoder._modules[key].shortcut.conv.conv) + block_modules = encoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(encoder._modules[key], SConv1d): + remove_weight_norm(encoder._modules[key].conv.conv) + + decoder = model.decoder.model + for key in decoder._modules: + if isinstance(decoder._modules[key], SEANetResnetBlock): + remove_weight_norm(decoder._modules[key].shortcut.conv.conv) + block_modules = decoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(decoder._modules[key], SConvTranspose1d): + remove_weight_norm(decoder._modules[key].convtr.convtr) + elif isinstance(decoder._modules[key], SConv1d): + remove_weight_norm(decoder._modules[key].conv.conv) + + +class AudioTokenizer: + """EnCodec audio.""" + + def __init__( + self, + device: Any = None, + ) -> None: + # Instantiate a pretrained EnCodec model + model = EncodecModel.encodec_model_24khz() + model.set_target_bandwidth(6.0) + remove_encodec_weight_norm(model) + + if not device: + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + + self._device = device + + self.codec = model.to(device) + self.sample_rate = model.sample_rate + self.channels = model.channels + + @property + def device(self): + return self._device + + def encode(self, wav: torch.Tensor) -> torch.Tensor: + return self.codec.encode(wav.to(self.device)) + + def decode(self, frames: torch.Tensor) -> torch.Tensor: + return self.codec.decode(frames) + + +@dataclass +class AudioTokenConfig: + frame_shift: Seconds = 320.0 / 24000 + num_quantizers: int = 8 + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig": + return AudioTokenConfig(**data) + + +class AudioTokenExtractor(FeatureExtractor): + name = "encodec" + config_type = AudioTokenConfig + + def __init__(self, config: Optional[Any] = None): + super(AudioTokenExtractor, self).__init__(config) + self.tokenizer = AudioTokenizer() + + def extract( + self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int + ) -> np.ndarray: + if not isinstance(samples, torch.Tensor): + samples = torch.from_numpy(samples) + if sampling_rate != self.tokenizer.sample_rate: + samples = convert_audio( + samples, + sampling_rate, + 
self.tokenizer.sample_rate, + self.tokenizer.channels, + ) + if len(samples.shape) == 2: + samples = samples.unsqueeze(0) + else: + raise ValueError() + + device = self.tokenizer.device + encoded_frames = self.tokenizer.encode(samples.detach().to(device)) + codes = encoded_frames[0][0] # [B, n_q, T] + if True: + duration = round(samples.shape[-1] / sampling_rate, ndigits=12) + expected_num_frames = compute_num_frames( + duration=duration, + frame_shift=self.frame_shift, + sampling_rate=sampling_rate, + ) + assert abs(codes.shape[-1] - expected_num_frames) <= 1 + codes = codes[..., :expected_num_frames] + return codes.cpu().squeeze(0).permute(1, 0).numpy() + + @property + def frame_shift(self) -> Seconds: + return self.config.frame_shift + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.num_quantizers + + def pad_tensor_list(self, tensor_list, device, padding_value=0): + lengths = [tensor.shape[0] for tensor in tensor_list] + tensor_list = [torch.Tensor(t).to(device) for t in tensor_list] + padded_tensor = torch.nn.utils.rnn.pad_sequence( + tensor_list, batch_first=True, padding_value=padding_value + ) + return padded_tensor, lengths + + def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray: + samples = [wav.squeeze() for wav in samples] + device = self.tokenizer.device + samples, lengths = self.pad_tensor_list(samples, device) + samples = samples.unsqueeze(1) + + if not isinstance(samples, torch.Tensor): + samples = torch.from_numpy(samples) + if len(samples.shape) != 3: + raise ValueError() + if sampling_rate != self.tokenizer.sample_rate: + samples = [ + convert_audio( + wav, + sampling_rate, + self.tokenizer.sample_rate, + self.tokenizer.channels, + ) + for wav in samples + ] + samples = torch.stack(samples, 0) # convert samples from list to tensor + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = self.tokenizer.encode(samples.detach().to(device)) + encoded_frames = encoded_frames[0][0] # [B, n_q, T] + batch_codes = [] + for b, length in enumerate(lengths): + codes = encoded_frames[b] + duration = round(length / sampling_rate, ndigits=12) + expected_num_frames = compute_num_frames( + duration=duration, + frame_shift=self.frame_shift, + sampling_rate=sampling_rate, + ) + batch_codes.append(codes[..., :expected_num_frames]) + return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes] + + +def main(): + args = get_args() + + dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip() + if dataset_parts == "all": # LibriTTS + dataset_parts = [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ] + else: + dataset_parts = dataset_parts.replace("-p", "").strip().split(" ") + + assert len(dataset_parts) >= 1 + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=args.src_dir, + prefix=args.prefix, + suffix=args.suffix, + types=["recordings", "supervisions", "cuts"], + ) + + text_tokenizer = None + if args.text_extractor: + text_tokenizer = TextTokenizer(backend=args.text_extractor) + + audio_extractor = None + if args.audio_extractor: + if args.audio_extractor == "Encodec": + audio_extractor = AudioTokenExtractor(AudioTokenConfig()) + else: + raise NotImplementedError(f"{args.audio_extractor}") + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + unique_symbols = set() + num_jobs = min(32, os.cpu_count()) + logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}") + 
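+    # Rough flow of the loop below (descriptive summary, no new logic):
+    #   1. normalize the manifest prefix so it ends with "_",
+    #   2. per partition, build a CutSet from recordings/supervisions (or reuse cached cuts)
+    #      and optionally split it into args.split pieces,
+    #   3. extract EnCodec codes with AudioTokenExtractor (batched on GPU when available),
+    #   4. phonemize the supervision text into cut.tokens and collect the symbol inventory,
+    #   5. save each piece as {prefix}cuts_{partition}[.{idx}].{suffix}.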
+ prefix = args.prefix + if prefix and not prefix.endswith("_"): + prefix = f"{prefix}_" + with get_executor() as ex: + for partition, m in manifests.items(): + logging.info( + f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" + ) + try: + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + except Exception: + cut_set = m["cuts"] + + # Split cut_set if split > 1 + split = 1 + if args.split > 1: + cut_sets = cut_set.split(args.split) + split = args.split + else: + cut_sets = [cut_set] + + for idx, part in enumerate(cut_sets): + if args.audio_extractor: + if args.audio_extractor == "Encodec": + if split > 1: + storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}" + else: + storage_path = ( + f"{args.output_dir}/{args.prefix}_encodec_{partition}" + ) + else: + if split > 1: + storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}" + else: + storage_path = ( + f"{args.output_dir}/{args.prefix}_fbank_{partition}" + ) + + if args.prefix.lower() in [ + "ljspeech", + "aishell", + "baker", + "wenetspeech4tts", + ]: + part = part.resample(24000) + assert args.prefix.lower() in [ + "ljspeech", + "aishell", + "baker", + "wenetspeech4tts", + "libritts", + "libritts-r", + ] + with torch.no_grad(): + if ( + torch.cuda.is_available() + and args.audio_extractor == "Encodec" + ): + part = part.compute_and_store_features_batch( + extractor=audio_extractor, + storage_path=storage_path, + num_workers=num_jobs, + batch_duration=args.batch_duration, + collate=False, + overwrite=True, + storage_type=NumpyHdf5Writer, + ) + else: + part = part.compute_and_store_features( + extractor=audio_extractor, + storage_path=storage_path, + num_jobs=num_jobs if ex is None else 64, + executor=ex, + storage_type=NumpyHdf5Writer, + ) + + # TextTokenizer + if args.text_extractor: + for c in tqdm(part): + if args.prefix == "ljspeech": + text = c.supervisions[0].custom["normalized_text"] + text = text.replace(""", '"').replace(""", '"') + phonemes = tokenize_text(text_tokenizer, text=text) + elif args.prefix in [ + "aishell", + "aishell2", + "wenetspeech4tts", + "libritts", + "libritts-r", + ]: + phonemes = tokenize_text( + text_tokenizer, text=c.supervisions[0].text + ) + if c.supervisions[0].custom is None: + c.supervisions[0].custom = {} + c.supervisions[0].normalized_text = c.supervisions[0].text + else: + raise NotImplementedError(f"{args.prefix}") + unique_symbols.update(phonemes) + c.tokens = phonemes + assert c.supervisions[ + 0 + ].normalized_text, "normalized_text is None" + + # Save each part with an index if split > 1 + if split > 1: + cuts_filename = f"{prefix}cuts_{partition}.{idx}.{args.suffix}" + else: + cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}" + + part.to_file(f"{args.output_dir}/{cuts_filename}") + logging.info(f"Saved {cuts_filename}") + + if args.text_extractor: + unique_phonemes = SymbolTable() + for s in sorted(list(unique_symbols)): + unique_phonemes.add(s) + logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}") + + unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols" + unique_phonemes.to_file(unique_phonemes_file) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech4tts/TTS/local/compute_wer.sh b/egs/wenetspeech4tts/TTS/local/compute_wer.sh new file mode 100644 index 000000000..283546383 
--- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/compute_wer.sh @@ -0,0 +1,26 @@ +wav_dir=$1 +wav_files=$(ls $wav_dir/*.wav) +# if wav_files is empty, then exit +if [ -z "$wav_files" ]; then + exit 1 +fi +label_file=$2 +model_path=local/sherpa-onnx-paraformer-zh-2023-09-14 + +if [ ! -d $model_path ]; then + pip install sherpa-onnx + wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local +fi + +python3 local/offline-decode-files.py \ + --tokens=$model_path/tokens.txt \ + --paraformer=$model_path/model.int8.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=24000 \ + --log-dir $wav_dir \ + --feature-dim=80 \ + --label $label_file \ + $wav_files diff --git a/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py b/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py new file mode 100755 index 000000000..f967dfd2b --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 (authors: Feiteng Li) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in the manifests. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. +""" + +import argparse +from pathlib import Path + +from lhotse import load_manifest_lazy + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/tokenized"), + help="Path to the tokenized manifests.", + ) + return parser.parse_args() + + +def main(): + args = get_args() + manifest_dir = args.manifest_dir or Path("data/tokenized") + for part in ["train", "dev", "test"]: + print(f"## {part}") + cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz") + cuts.describe() + print("\n") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/local/fbank.py b/egs/wenetspeech4tts/TTS/local/fbank.py new file mode 120000 index 000000000..3cfb7fe3f --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/fbank.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/fbank.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/local/offline-decode-files.py b/egs/wenetspeech4tts/TTS/local/offline-decode-files.py new file mode 100755 index 000000000..fa6cbdb3e --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/offline-decode-files.py @@ -0,0 +1,495 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 by manyeyes +# Copyright (c) 2023 Xiaomi Corporation + +""" +This file demonstrates how to use sherpa-onnx Python API to transcribe +file(s) with a non-streaming model. 
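+
+In this recipe the script is driven by local/compute_wer.sh, which points it at a
+paraformer Chinese model and passes --label with the reference texts, so that the
+synthesized wavs can be scored with error statistics (WER/CER).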
+ +(1) For paraformer + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/paraformer.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(2) For transducer models from icefall + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(3) For CTC models from NeMo + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \ + --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav + +(4) For Whisper models + +python3 ./python-api-examples/offline-decode-files.py \ + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ + --whisper-task=transcribe \ + --num-threads=1 \ + ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav + +(5) For CTC models from WeNet + +python3 ./python-api-examples/offline-decode-files.py \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav + +(6) For tdnn models of the yesno recipe from icefall + +python3 ./python-api-examples/offline-decode-files.py \ + --sample-rate=8000 \ + --feature-dim=23 \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/index.html +to install sherpa-onnx and to download non-streaming pre-trained models +used in this file. +""" +import argparse +import time +import wave +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import sherpa_onnx +import soundfile as sf + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, like + HELLO WORLD + 你好世界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. 
+ """, + ) + + parser.add_argument( + "--modeling-unit", + type=str, + default="", + help=""" + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. + Used only when hotwords-file is given. + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + default="", + help=""" + The path to the bpe vocabulary, the bpe vocabulary is generated by + sentencepiece, you can also export the bpe vocabulary through a bpe model + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given + and modeling-unit is bpe or cjkchar+bpe. + """, + ) + + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--nemo-ctc", + default="", + type=str, + help="Path to the model.onnx from NeMo CTC", + ) + + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the model.onnx from WeNet CTC", + ) + + parser.add_argument( + "--tdnn-model", + default="", + type=str, + help="Path to the model.onnx for the tdnn model of the yesno recipe", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. + """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="""Sample rate of the feature extractor. Must match the one + expected by the model. 
Note: The input sound files can have a + different sample rate from this argument.""", + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=80, + help="Feature dimension. Must match the one expected by the model", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + parser.add_argument( + "--name", + type=str, + default="", + help="The directory containing the input sound files to decode", + ) + + parser.add_argument( + "--log-dir", + type=str, + default="", + help="The directory containing the input sound files to decode", + ) + + parser.add_argument( + "--label", + type=str, + default=None, + help="wav_base_name label", + ) + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and can be of type + 32-bit floating point PCM. Its sample rate does not need to be 24kHz. + + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, + which are normalized to the range [-1, 1]. + - Sample rate of the wave file. + """ + + samples, sample_rate = sf.read(wave_filename, dtype="float32") + assert ( + samples.ndim == 1 + ), f"Expected single channel, but got {samples.ndim} channels." + + samples_float32 = samples.astype(np.float32) + + return samples_float32, sample_rate + + +def normalize_text_alimeeting(text: str) -> str: + """ + Text normalization similar to M2MeT challenge baseline. 
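+    Spaces, meeting-transcription tags such as <%> and CJK punctuation are removed,
+    and Latin letters are upper-cased when present, e.g.
+    "这是 一个 test。" -> "这是一个TEST".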
+ See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + import re + + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + return text + + +def main(): + args = get_args() + assert_file_exists(args.tokens) + assert args.num_threads > 0, args.num_threads + + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.paraformer) + + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + + print("Started!") + start_time = time.time() + + streams, results = [], [] + total_duration = 0 + + for i, wave_filename in enumerate(args.sound_files): + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + total_duration += duration + s = recognizer.create_stream() + s.accept_waveform(sample_rate, samples) + + streams.append(s) + if i % 10 == 0: + recognizer.decode_streams(streams) + results += [s.result.text for s in streams] + streams = [] + print(f"Processed {i} files") + # process the last batch + if streams: + recognizer.decode_streams(streams) + results += [s.result.text for s in streams] + end_time = time.time() + print("Done!") + + results_dict = {} + for wave_filename, result in zip(args.sound_files, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + wave_basename = Path(wave_filename).stem + results_dict[wave_basename] = result + + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + print(f"num_threads: {args.num_threads}") + print(f"decoding_method: {args.decoding_method}") + print(f"Wave duration: {total_duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + if args.label: + from icefall.utils import store_transcripts, write_error_stats + + labels_dict = {} + with open(args.label, "r") as f: + for line in f: + # fields = line.strip().split(" ") + # fields = [item for item in fields if item] + # assert len(fields) == 4 + # prompt_text, prompt_audio, text, audio_path = fields + + fields = line.strip().split("|") + fields = [item for item in fields if item] + assert len(fields) == 4 + audio_path, prompt_text, prompt_audio, text = fields + labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text) + + final_results = [] + for key, value in results_dict.items(): + 
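+            # key is the wav basename and value the decoded hypothesis; pair it with
+            # the normalized reference text so that store_transcripts/write_error_stats
+            # below can produce recogs-*.txt and errs-*.txt in --log-dir.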
final_results.append((key, labels_dict[key], value)) + + store_transcripts( + filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results + ) + with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f: + write_error_stats(f, "test-set", final_results, enable_log=True) + + with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f: + print(f.readline()) # WER + print(f.readline()) # Detailed errors + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/prepare.sh b/egs/wenetspeech4tts/TTS/prepare.sh new file mode 100755 index 000000000..f1daa0e62 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/prepare.sh @@ -0,0 +1,165 @@ +#!/usr/bin/env bash + +set -eou pipefail + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +stage=1 +stop_stage=4 + +dl_dir=$PWD/download + +dataset_parts="Premium" # Basic for all 7226 hours data, Premium for 945 hours subset. + +text_extractor="pypinyin_initials_finals" # default is espeak for English +audio_extractor="Encodec" # or Fbank +audio_feats_dir=data/tokenized + +. shared/parse_options.sh || exit 1 + + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "dl_dir: $dl_dir" + log "Stage 0: Download data" + huggingface-cli login + huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS + + # Extract the downloaded data: + for folder in Standard Premium Basic; do + for file in "$dl_dir/$folder"/*.tar.gz; do + tar -xzvf "$file" -C "$dl_dir/$folder" + done + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare wenetspeech4tts manifest" + # We assume that you have downloaded the wenetspeech4tts corpus + # to $dl_dir/wenetspeech4tts + mkdir -p data/manifests + if [ ! -e data/manifests/.wenetspeech4tts.done ]; then + lhotse prepare wenetspeech4tts $dl_dir data/manifests --dataset-parts "${dataset_parts}" + touch data/manifests/.wenetspeech4tts.done + fi +fi + + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Tokenize/Fbank wenetspeech4tts" + mkdir -p ${audio_feats_dir} + if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.tokenize.done ]; then + python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ + --text-extractor ${text_extractor} \ + --audio-extractor ${audio_extractor} \ + --batch-duration 2500 --prefix "wenetspeech4tts" \ + --src-dir "data/manifests" \ + --split 100 \ + --output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100" + cp ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100/unique_text_tokens.k2symbols ${audio_feats_dir} + fi + touch ${audio_feats_dir}/.wenetspeech4tts.tokenize.done +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Combine features" + if [ ! 
-f ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz ]; then + pieces=$(find ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100 -name "*.jsonl.gz") + lhotse combine $pieces ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare wenetspeech4tts train/dev/test" + if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.train.done ]; then + + lhotse subset --first 400 \ + ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ + ${audio_feats_dir}/cuts_dev.jsonl.gz + + lhotse subset --last 400 \ + ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ + ${audio_feats_dir}/cuts_test.jsonl.gz + + lhotse copy \ + ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ + ${audio_feats_dir}/cuts_train.jsonl.gz + + touch ${audio_feats_dir}/.wenetspeech4tts.train.done + fi + python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir} +fi + +subset="Basic" +prefix="wenetspeech4tts" +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate fbank (used by ./f5-tts)" + mkdir -p data/fbank + if [ ! -e data/fbank/.${prefix}.done ]; then + ./local/compute_mel_feat.py --dataset-parts $subset --split 100 + touch data/fbank/.${prefix}.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)" + if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then + echo "Combining ${prefix} cuts" + pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz") + lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz + fi + if [ ! -e data/fbank/.${prefix}_split.done ]; then + echo "Splitting ${prefix} cuts into train, valid and test sets" + + lhotse subset --last 800 \ + data/fbank/${prefix}_cuts_${subset}.jsonl.gz \ + data/fbank/${prefix}_cuts_validtest.jsonl.gz + lhotse subset --first 400 \ + data/fbank/${prefix}_cuts_validtest.jsonl.gz \ + data/fbank/${prefix}_cuts_valid.jsonl.gz + lhotse subset --last 400 \ + data/fbank/${prefix}_cuts_validtest.jsonl.gz \ + data/fbank/${prefix}_cuts_test.jsonl.gz + + rm data/fbank/${prefix}_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 )) + lhotse subset --first $n \ + data/fbank/${prefix}_cuts_${subset}.jsonl.gz \ + data/fbank/${prefix}_cuts_train.jsonl.gz + touch data/fbank/.${prefix}_split.done + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Extract cosyvoice2 FSQ token (used by ./f5-tts semantic token experiment)" + split_name=("valid" "test" "train") + for split in "${split_name[@]}"; do + echo "Processing $split" + wav_scp_file=wav_${split}.scp + output_dir="./cosy_v2_tokens_${split}" + oringinal_jsonl_file=data/fbank/${prefix}_cuts_${split}.jsonl.gz + mkdir -p $output_dir + zcat $oringinal_jsonl_file | jq -r '.recording.id + " " + .recording.sources[0].source' > $wav_scp_file + torchrun --nproc_per_node=8 --nnodes=1 \ + --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ + `which s3tokenizer` --wav_scp $wav_scp_file \ + --device "cuda" \ + --output_dir $output_dir \ + --batch_size 32 \ + --num_workers 4 \ + --model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz + + cat $output_dir/* > $output_dir/${prefix}_${split}_cosy_v2_tokens.json + python3 local/attach_speech_tokens.py --jsonl-prefix ${prefix}_cuts_${split} --tokens-path 
$output_dir/${prefix}_${split}_cosy_v2_tokens.json --manifest-dir data/fbank + done +fi diff --git a/egs/wenetspeech4tts/TTS/shared b/egs/wenetspeech4tts/TTS/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py new file mode 120000 index 000000000..e70ee319a --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py @@ -0,0 +1 @@ +../local/compute_neural_codec_and_prepare_text_tokens.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/infer.py b/egs/wenetspeech4tts/TTS/valle/infer.py new file mode 100644 index 000000000..1f8f285f8 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/infer.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script is used to synthesize speech from text prompts and audio prompts. +Usage example: + python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \ + --checkpoint=${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ + --text-prompts "KNOT one point one five miles per hour." \ + --audio-prompts ./prompts/8463_294825_000043_000000.wav \ + --text "To get up and running quickly just follow the steps below." 
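+
+When --text points to a file (as in the second example below), each line is expected to
+contain four space-separated fields:
+
+    prompt_text prompt_wav target_text output_wav_name
+
+(the fields are split on spaces, so none of them may contain spaces); the synthesized
+audio for every line is written under --output-dir.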
+ + top_p=1.0 + python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ + --top-k -1 --temperature 1.0 \ + --text ./aishell3.txt \ + --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ + --text-extractor pypinyin_initials_finals --top-p ${top_p} + +""" +import argparse +import logging +import os +from pathlib import Path + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + +import torch +import torchaudio +from compute_neural_codec_and_prepare_text_tokens import ( + AudioTokenizer, + TextTokenizer, + tokenize_text, +) +from encodec.utils import convert_audio +from k2 import symbol_table +from tokenizer import get_text_token_collater +from valle import VALLE + +from icefall.utils import AttributeDict, str2bool + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--text-prompts", + type=str, + default="", + help="Text prompts which are separated by |.", + ) + + parser.add_argument( + "--audio-prompts", + type=str, + default="", + help="Audio prompts which are separated by | and should be aligned with --text-prompts.", + ) + + parser.add_argument( + "--text", + type=str, + default="", + help="prompt text\t prompt audio\ttarget text\ttarget audio", + ) + + parser.add_argument( + "--text-extractor", + type=str, + default="espeak", + help="espeak or pypinyin or pypinyin_initials_finals", + ) + + parser.add_argument( + "--checkpoint", + type=str, + default="./valle/exp/checkpoint-100000.pt", + help="Path to the saved checkpoint.", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("infer/demo"), + help="Path to the tokenized files.", + ) + + parser.add_argument( + "--top-k", + type=int, + default=-100, + help="Whether AR Decoder do top_k(if > 0) sampling.", + ) + + parser.add_argument( + "--top-p", + type=float, + default=1.0, + help="Whether AR Decoder do top_p(if > 0) sampling.", + ) + + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="The temperature of AR Decoder top_k sampling.", + ) + + parser.add_argument( + "--repetition-aware-sampling", + type=str2bool, + default=False, + help="Whether AR Decoder do valle-2 repetition-aware sampling. 
https://arxiv.org/pdf/2406.05370", + ) + + return parser.parse_args() + + +def load_model(checkpoint, device): + if not checkpoint: + return None + + checkpoint = torch.load(checkpoint, map_location=device, weights_only=False) + + params = AttributeDict(checkpoint) + model = VALLE( + params.decoder_dim, + params.nhead, + params.num_decoder_layers, + norm_first=params.norm_first, + add_prenet=params.add_prenet, + prefix_mode=params.prefix_mode, + share_embedding=params.share_embedding, + nar_scale_factor=params.scale_factor, + prepend_bos=params.prepend_bos, + num_quantizers=params.num_quantizers, + ) + + missing_keys, unexpected_keys = model.load_state_dict( + checkpoint["model"], strict=True + ) + assert not missing_keys + model.to(device) + model.eval() + + return model, params.text_tokens + + +def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str): + # Load and pre-process the audio waveform + wav, sr = torchaudio.load(audio_path) + wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) + wav = wav.unsqueeze(0) + + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = tokenizer.encode(wav) + return encoded_frames + + +@torch.no_grad() +def main(): + args = get_args() + text_tokenizer = TextTokenizer(backend=args.text_extractor) + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + model, text_tokens = load_model(args.checkpoint, device) + + text_collater = get_text_token_collater(text_tokens) + + audio_tokenizer = AudioTokenizer() + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + text_prompts = " ".join(args.text_prompts.split("|")) + + audio_prompts = [] + if args.audio_prompts: + for n, audio_file in enumerate(args.audio_prompts.split("|")): + encoded_frames = tokenize_audio(audio_tokenizer, audio_file) + if False: + samples = audio_tokenizer.decode(encoded_frames) + torchaudio.save(f"{args.output_dir}/p{n}.wav", samples[0], 24000) + + audio_prompts.append(encoded_frames[0][0]) + + assert len(args.text_prompts.split("|")) == len(audio_prompts) + audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1) + audio_prompts = audio_prompts.to(device) + + if os.path.isfile(args.text): # for demos + # https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py + with open(args.text) as f: + for line in f: + fields = line.strip().split(" ") + fields = [item for item in fields if item] + assert len(fields) == 4 + prompt_text, prompt_audio, text, audio_path = fields + logging.info(f"synthesize text: {text}") + text_tokens, text_tokens_lens = text_collater( + [ + tokenize_text( + text_tokenizer, text=f"{prompt_text} {text}".strip() + ) + ] + ) + _, enroll_x_lens = text_collater( + [tokenize_text(text_tokenizer, text=f"{prompt_text}".strip())] + ) + + audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio) + audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device) + + # synthesis + encoded_frames = model.inference( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + enroll_x_lens=enroll_x_lens, + top_k=args.top_k, + temperature=args.temperature, + top_p=args.top_p, + ras=args.repetition_aware_sampling, + ) + + samples = audio_tokenizer.decode( + [(encoded_frames.transpose(2, 1), None)] + ) + # store + # save audio path into args.output_dir + audio_path + audio_path = f"{args.output_dir}/{audio_path}" + # mkdir -p + os.makedirs(os.path.dirname(audio_path), exist_ok=True) + torchaudio.save(audio_path, samples[0].cpu(), 
24000) + return + + for n, text in enumerate(args.text.split("|")): + logging.info(f"synthesize text: {text}") + text_tokens, text_tokens_lens = text_collater( + [tokenize_text(text_tokenizer, text=f"{text_prompts} {text}".strip())] + ) + + # synthesis + enroll_x_lens = None + if text_prompts: + _, enroll_x_lens = text_collater( + [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())] + ) + encoded_frames = model.inference( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + enroll_x_lens=enroll_x_lens, + top_k=args.top_k, + temperature=args.temperature, + top_p=args.top_p, + ras=args.repetition_aware_sampling, + ) + + if audio_prompts != []: + samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)]) + # store + torchaudio.save(f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000) + else: # Transformer + pass + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech4tts/TTS/valle/optim.py b/egs/wenetspeech4tts/TTS/valle/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/requirements.txt b/egs/wenetspeech4tts/TTS/valle/requirements.txt new file mode 100644 index 000000000..06958dbea --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/requirements.txt @@ -0,0 +1,2 @@ +phonemizer==3.2.1 +git+https://github.com/facebookresearch/encodec.git \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/tokenizer.py b/egs/wenetspeech4tts/TTS/valle/tokenizer.py new file mode 100644 index 000000000..db4f00396 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/tokenizer.py @@ -0,0 +1,111 @@ +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import torch +from k2 import SymbolTable + + +class TextTokenCollater: + """Collate list of text tokens + + Map sentences to integers. Sentences are padded to equal length. + Beginning and end-of-sequence symbols can be added. + + Example: + >>> token_collater = TextTokenCollater(text_tokens) + >>> tokens_batch, tokens_lens = token_collater(text) + + Returns: + tokens_batch: IntTensor of shape (B, L) + B: batch dimension, number of input sentences + L: length of the longest sentence + tokens_lens: IntTensor of shape (B,) + Length of each sentence after adding and + but before padding. 
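+
+    For example, with add_bos=True and add_eos=True a 5-token sentence contributes
+    7 ids, shorter sentences in the batch are padded with pad_symbol up to the
+    longest length, and tokens_lens stores 7 for that sentence.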
+ """ + + def __init__( + self, + text_tokens: List[str], + add_eos: bool = True, + add_bos: bool = True, + pad_symbol: str = "", + bos_symbol: str = "", + eos_symbol: str = "", + ): + self.pad_symbol = pad_symbol + + self.add_eos = add_eos + self.add_bos = add_bos + + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + unique_tokens = ( + [pad_symbol] + + ([bos_symbol] if add_bos else []) + + ([eos_symbol] if add_eos else []) + + sorted(text_tokens) + ) + + self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} + self.idx2token = [token for token in unique_tokens] + + def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + seqs, seq_lens = [], [] + for tokens in tokens_list: + assert all([True if s in self.token2idx else False for s in tokens]) is True + seq = ( + ([self.bos_symbol] if self.add_bos else []) + + list(tokens) + + ([self.eos_symbol] if self.add_eos else []) + ) + seqs.append(seq) + seq_lens.append(len(seq)) + + max_len = max(seq_lens) + for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)): + seq.extend([self.pad_symbol] * (max_len - seq_len)) + + tokens = torch.from_numpy( + np.array( + [[self.token2idx[token] for token in seq] for seq in seqs], + dtype=np.int64, + ) + ) + tokens_lens = torch.IntTensor(seq_lens) + + return tokens, tokens_lens + + def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + tokens_seqs = [[p for p in text] for text in texts] + max_len = len(max(tokens_seqs, key=len)) + + seqs = [ + ([self.bos_symbol] if self.add_bos else []) + + list(seq) + + ([self.eos_symbol] if self.add_eos else []) + + [self.pad_symbol] * (max_len - len(seq)) + for seq in tokens_seqs + ] + + tokens_batch = torch.from_numpy( + np.array( + [[self.token2idx[token] for token in seq] for seq in seqs], + dtype=np.int64, + ) + ) + + tokens_lens = torch.IntTensor( + [len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs] + ) + + return tokens_batch, tokens_lens + + +def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: + text_tokens_path = Path(text_tokens_file) + unique_tokens = SymbolTable.from_file(text_tokens_path) + collater = TextTokenCollater(unique_tokens.symbols, add_bos=True, add_eos=True) + return collater diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py new file mode 100755 index 000000000..e9ec548f3 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/train.py @@ -0,0 +1,1243 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo) +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# Copyright 2024 Tsinghua University (authors: Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +world_size=8 +exp_dir=exp/valle + +## Train AR model +python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +## Train NAR model +# cd ${exp_dir} +# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 +# cd - +python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ + --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ + --exp-dir ${exp_dir} --world-size ${world_size} +""" + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from tokenizer import TextTokenCollater, get_text_token_collater +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import TtsDataModule +from valle import VALLE + +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoder-dim", + type=int, + default=1024, + help="Embedding dimension in the decoder model.", + ) + parser.add_argument( + "--nhead", + type=int, + default=16, + help="Number of attention heads in the Decoder layers.", + ) + parser.add_argument( + "--num-decoder-layers", + type=int, + default=12, + help="Number of Decoder layers.", + ) + parser.add_argument( + "--scale-factor", + type=float, + default=1.0, + help="Model scale factor which will be assigned different meanings in different models.", + ) + parser.add_argument( + "--norm-first", + type=str2bool, + default=True, + help="Pre or Post Normalization.", + ) + parser.add_argument( + "--add-prenet", + 
type=str2bool, + default=False, + help="Whether add PreNet after Inputs.", + ) + + parser.add_argument( + "--prefix-mode", + type=int, + default=0, + help="The mode for how to prefix VALL-E NAR Decoder, " + "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.", + ) + parser.add_argument( + "--share-embedding", + type=str2bool, + default=True, + help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.", + ) + parser.add_argument( + "--prepend-bos", + type=str2bool, + default=False, + help="Whether prepend to the acoustic tokens -> AR Decoder inputs.", + ) + parser.add_argument( + "--num-quantizers", + type=int, + default=8, + help="Number of Audio/Semantic quantization layers.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="./valle/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--text-tokens", + type=str, + default="data/tokenized/unique_text_tokens.k2symbols", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--optimizer-name", + type=str, + default="ScaledAdam", + help="The optimizer.", + ) + parser.add_argument( + "--scheduler-name", + type=str, + default="Eden", + help="The scheduler.", + ) + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=200, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train %% save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. 
+ """, + ) + parser.add_argument( + "--valid-interval", + type=int, + default=10000, + help="""Run validation if batch_idx %% valid_interval is 0.""", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=0, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accumulate-grad-steps", + type=int, + default=1, + help="""update gradient when batch_idx_train %% accumulate_grad_steps == 0. + """, + ) + + parser.add_argument( + "--dtype", + type=str, + default="float32", + help="Training dtype: float32 bfloat16 float16.", + ) + + parser.add_argument( + "--filter-min-duration", + type=float, + default=0.0, + help="Keep only utterances with duration > this.", + ) + parser.add_argument( + "--filter-max-duration", + type=float, + default=20.0, + help="Keep only utterances with duration < this.", + ) + + parser.add_argument( + "--train-stage", + type=int, + default=0, + help="""0: train all modules, For VALL-E, support 1: AR Decoder 2: NAR Decoder(s) + """, + ) + + parser.add_argument( + "--visualize", + type=str2bool, + default=False, + help="visualize model results in eval step.", + ) + + parser.add_argument( + "--oom-check", + type=str2bool, + default=False, + help="perform OOM check on dataloader batches before starting training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. 
+ + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 200, + "valid_interval": 10000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
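+    # Example of the resolution above: with --start-epoch 100 (the NAR stage in
+    # the usage notes at the top of this file), exp-dir/epoch-99.pt is loaded,
+    # which is why that stage first symlinks best-valid-loss.pt to epoch-99.pt.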
+ + if isinstance(model, DDP): + raise ValueError("load_checkpoint before DDP") + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + saved_stage = saved_params.get("train_stage", 0) + if params.train_stage != saved_stage: + # switch training stage + if params.train_stage and saved_stage: # switch between 1 and 2 + params.start_epoch = 1 + params.start_batch = 0 + else: + # switch between 0 and 1/2 + assert params.num_epochs >= params.start_epoch + params.batch_idx_train = saved_params["batch_idx_train"] + + for key in ["optimizer", "grad_scaler", "sampler"]: + if key in saved_params: + saved_params.pop(key) + + # when base on stage 0, we keep scheduler + if saved_stage != 0: + for key in ["scheduler"]: + if key in saved_params: + saved_params.pop(key) + + best_train_filename = params.exp_dir / "best-train-loss.pt" + if best_train_filename.is_file(): + copyfile( + src=best_train_filename, + dst=params.exp_dir / f"best-train-loss-stage{saved_stage}.pt", + ) + + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + if best_valid_filename.is_file(): + copyfile( + src=best_valid_filename, + dst=params.exp_dir / f"best-valid-loss-stage{saved_stage}.pt", + ) + else: + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
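+      scheduler:
+        The learning rate scheduler used in the training.
+      rank:
+        The rank of the node in DDP training. Only rank 0 writes the
+        checkpoint; other ranks return immediately.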
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def prepare_input(batch: dict, tokenizer: TextTokenCollater, device: torch.device): + """Parse batch data""" + + features = batch["features"].to(device) + features_lens = batch["features_lens"].to(device) + if "tokens" not in batch: + raise NotImplementedError("Need to tokenize text") + # tokens = [] + # for c in batch["cuts"]: + # phonemes = tokenize_text( + # tokenizer, text=c.supervisions[0].text + # ) + # tokens.append(phonemes) + else: + tokens = batch["tokens"] + + text_tokens, text_tokens_lens = tokenizer(tokens) + text_tokens = text_tokens.to(device) + text_tokens_lens = text_tokens_lens.to(device) + + return features, features_lens, text_tokens, text_tokens_lens + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + ( + audio_features, + audio_features_lens, + text_tokens, + text_tokens_lens, + ) = prepare_input(batch, tokenizer, device) + # at entry, TextTokens is (N, P) + assert text_tokens.ndim == 2 + assert audio_features.ndim == 3 + + with torch.set_grad_enabled(is_training): + predicts, loss, metrics = model( + x=text_tokens, + x_lens=text_tokens_lens, + y=audio_features, + y_lens=audio_features_lens, + train_stage=params.train_stage, + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (audio_features_lens).sum().item() + info["utterances"] = text_tokens.size(0) + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + for metric in metrics: + info[metric] = metrics[metric].detach().cpu().item() + del metrics + + return predicts, loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + predicts, loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + if world_size > 1: + tot_loss.reduce(loss.device) + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + if params.visualize: + output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}") + output_dir.mkdir(parents=True, exist_ok=True) + if isinstance(model, DDP): + model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir) + else: + model.visualize(predicts, batch, tokenizer, output_dir=output_dir) + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + rng: + Random for selecting. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + tot_loss = MetricsTracker() + iter_dl = iter(train_dl) + + dtype, enabled = torch.float32, False + if params.dtype in ["bfloat16", "bf16"]: + dtype, enabled = torch.bfloat16, True + elif params.dtype in ["float16", "fp16"]: + dtype, enabled = torch.float16, True + + batch_idx = 0 + while True: + try: + batch = next(iter_dl) + except StopIteration: + logging.info("Reaches end of dataloader.") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["text"]) + + try: + with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled): + _, loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * ( + 1 / params.reset_interval + ) + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + scaler.scale(loss).backward() + if params.batch_idx_train >= params.accumulate_grad_steps: + if params.batch_idx_train % params.accumulate_grad_steps == 0: + if params.optimizer_name not in ["ScaledAdam", "Eve"]: + # Unscales the gradients of optimizer's assigned params in-place + scaler.unscale_(optimizer) + # Since the gradients of optimizer's assigned params are unscaled, clips as usual: + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + for k in range(params.accumulate_grad_steps): + if isinstance(scheduler, Eden): + scheduler.step_batch(params.batch_idx_train) + else: + scheduler.step() + + set_batch_count(model, params.batch_idx_train) + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.average_period > 0: + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = ( + scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0 + ) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, train_loss[{loss_info}], " + f"tot_loss[{tot_loss}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + + ( + f", grad_scale: {cur_grad_scale}" + if params.dtype in ["float16", "fp16"] + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, + "train/current_", + params.batch_idx_train, + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.dtype in ["float16", "fp16"]: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if params.batch_idx_train % params.valid_interval == 0: + # Calculate validation loss in Rank 0 + model.eval() + logging.info("Computing validation loss") + with torch.cuda.amp.autocast(dtype=dtype): + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + model.train() + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances( + cuts: CutSet, min_duration: float, max_duration: float +) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 0.6 second and 20 seconds + if c.duration < min_duration or c.duration > max_duration: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + if params.train_stage: + tb_writer = SummaryWriter( + log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}" + ) + else: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info(f"Device: {device}") + + tokenizer = get_text_token_collater(params.text_tokens) + logging.info(params) + + logging.info("About to create model") + + model = VALLE( + params.decoder_dim, + params.nhead, + params.num_decoder_layers, + norm_first=params.norm_first, + add_prenet=params.add_prenet, + prefix_mode=params.prefix_mode, + share_embedding=params.share_embedding, + nar_scale_factor=params.scale_factor, + prepend_bos=params.prepend_bos, + num_quantizers=params.num_quantizers, + ) + + with open(f"{params.exp_dir}/model.txt", "w") as f: + print(model) + print(model, file=f) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0 and params.average_period > 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + if params.train_stage: + _model = model.module if isinstance(model, DDP) else model + model_parameters = _model.stage_parameters(params.train_stage) + else: + model_parameters = model.parameters() + + if params.optimizer_name == "ScaledAdam": + optimizer = ScaledAdam( + model_parameters, + lr=params.base_lr, + clipping_scale=2.0, + ) + elif params.optimizer_name == "AdamW": + optimizer = torch.optim.AdamW( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + weight_decay=1e-2, + eps=1e-8, + ) + elif params.optimizer_name == "Adam": + optimizer = torch.optim.Adam( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + eps=1e-8, + ) + else: + raise NotImplementedError() + + scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps) + optimizer.zero_grad() + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.inf_check: + register_inf_check_hooks(model) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + dataset = TtsDataModule(args) + 
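+    # Build the train/dev cut sets, filter them by duration, and wrap them in
+    # dataloaders; if training resumed from a batch-level checkpoint, the
+    # sampler state loaded above is handed to the train dataloader so that
+    # iteration continues where it stopped.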
train_cuts = dataset.train_cuts() + valid_cuts = dataset.dev_cuts() + + train_cuts = filter_short_and_long_utterances( + train_cuts, params.filter_min_duration, params.filter_max_duration + ) + valid_cuts = filter_short_and_long_utterances( + valid_cuts, params.filter_min_duration, params.filter_max_duration + ) + + train_dl = dataset.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + valid_dl = dataset.dev_dataloaders(valid_cuts) + + if params.oom_check: + scan_pessimistic_batches_for_oom( + model=model, + tokenizer=tokenizer, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + if isinstance(scheduler, Eden): + scheduler.step_epoch(epoch - 1) + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + + dtype = torch.float32 + if params.dtype in ["bfloat16", "bf16"]: + dtype = torch.bfloat16 + elif params.dtype in ["float16", "fp16"]: + dtype = torch.float16 + + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(dtype=dtype): + _, loss, _ = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + TtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py b/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py new file mode 100644 index 000000000..8e34d06dc --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py @@ -0,0 +1,343 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (Author: Yuekai Zhang) +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.features.io import KaldiReader +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class TtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in TTS + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
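+
+    A minimal usage sketch (illustrative only; the options are the ones
+    registered by add_arguments() below):
+
+        parser = argparse.ArgumentParser()
+        TtsDataModule.add_arguments(parser)
+        args = parser.parse_args()
+
+        dm = TtsDataModule(args)
+        train_dl = dm.train_dataloaders(dm.train_cuts())
+        dev_dl = dm.dev_dataloaders(dm.dev_cuts())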
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/tokenized"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--speaker-embeds", + type=Path, + default=Path("exp/xvector_nnet_1a/"), + help="Path to directory with speaker embeddings.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=4, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=False, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--dataset", + type=str, + default="libritts", + help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="""Audio sampling rate.""", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + raise NotImplementedError + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + raise NotImplementedError + else: + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + dev_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + dev_dl = DataLoader( + validate, + sampler=dev_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return dev_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + raise NotImplementedError + else: + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz") + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz") + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def 
dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" + ) diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py new file mode 100644 index 000000000..8f9b8fc3d --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -0,0 +1,1731 @@ +# Copyright 2023 (authors: Feiteng Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import numbers +import random +from functools import partial +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +from tokenizer import TextTokenCollater +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn import functional as F +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear +from torch.nn.parameter import Parameter +from torchmetrics.classification import MulticlassAccuracy + +from icefall.utils import make_pad_mask + +NUM_TEXT_TOKENS = 5000 +NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins + + +class PromptedFeatures: + def __init__(self, prompts, features): + self.prompts = prompts + self.features = features + + def to(self, device): + return PromptedFeatures(self.prompts.to(device), self.features.to(device)) + + def sum(self): + return self.features.sum() + + @property + def ndim(self): + return self.features.ndim + + @property + def data(self): + return (self.prompts, self.features) + + +class TokenEmbedding(nn.Module): + def __init__( + self, + dim_model: int, + vocab_size: int, + dropout: float = 0.0, + ): + super().__init__() + + self.vocab_size = vocab_size + self.dim_model = dim_model + + self.dropout = torch.nn.Dropout(p=dropout) + self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) + + @property + def weight(self) -> torch.Tensor: + return self.word_embeddings.weight + + def embedding(self, index: int) -> torch.Tensor: + return self.word_embeddings.weight[index : index + 1] + + def forward(self, x: torch.Tensor): + X = self.word_embeddings(x) + X = self.dropout(X) + + return X + + +class SinePositionalEmbedding(nn.Module): + def __init__( + self, + dim_model: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + ): + super().__init__() + self.dim_model = dim_model + self.x_scale = math.sqrt(dim_model) if scale else 1.0 + self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) + self.dropout = torch.nn.Dropout(p=dropout) 
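+        # When alpha=True the positional term is weighted by a learnable
+        # scalar; otherwise self.alpha stays frozen at 1.0 (requires_grad=False).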
+ + self.reverse = False + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, 4000)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.dim_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.dim_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.dim_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype).detach() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.extend_pe(x) + output = x.unsqueeze(-1) if x.ndim == 2 else x + output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(output) + + +class Transpose(nn.Identity): + """(N, T, D) -> (N, D, T)""" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input.transpose(1, 2) + + +_shape_t = Union[int, List[int], torch.Size] + + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces as described in the paper: + `Attention Is All You Need `_. + + Multi-Head Attention is defined as: + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + ``forward()`` will use a special optimized implementation if all of the following + conditions are met: + + - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This + restriction will be loosened in the future.) + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - dropout is 0 + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``batch_first`` is ``True`` and the input is batched + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - at most one of ``key_padding_mask`` or ``attn_mask`` is passed + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + + If the optimized implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). 
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + """ + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if add_bias_kv: + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + if linear1_cls == Linear: + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter( + torch.empty(3 * embed_dim, **factory_kwargs) + ) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + self._reset_parameters() + else: + if not self._qkv_same_embed_dim: + raise NotImplementedError + else: + self.in_proj_linear = linear1_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) + self.in_proj_weight = self.in_proj_linear.weight + + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = self.in_proj_linear.bias + else: + self.register_parameter("in_proj_bias", None) + + self.out_proj = linear2_cls( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + self.add_zero_attn = add_zero_attn + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + 
constant_(self.out_proj.bias, 0.0) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. 
average weights across heads) + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ + is_batched = query.dim() == 3 + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != torch.bool and not torch.is_floating_point( + key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + why_not_fast_path = "" + if not is_batched: + why_not_fast_path = ( + f"input not batched; expected query.dim() of 3 but got {query.dim()}" + ) + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + elif ( + self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype + ): + # this case will fail anyway, but at least they'll get a useful error message. + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + elif self.training: + why_not_fast_path = "training is enabled" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.dropout: + why_not_fast_path = f"dropout was {self.dropout}, required zero" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif attn_mask is not None: + why_not_fast_path = "attn_mask was not None" + elif query.is_nested and key_padding_mask is not None: + why_not_fast_path = ( + "key_padding_mask is not supported with NestedTensor input" + ) + elif self.num_heads % 2 == 1: + why_not_fast_path = "num_heads is odd" + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. 
+ if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif not all( + [ + (x is None or x.is_cuda or "cpu" in str(x.device)) + for x in tensor_args + ] + ): + why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" + elif torch.is_grad_enabled() and any( + [x is not None and x.requires_grad for x in tensor_args] + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + if not why_not_fast_path: + return torch._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + key_padding_mask if key_padding_mask is not None else attn_mask, + need_weights, + average_attn_weights, + 1 + if key_padding_mask is not None + else 0 + if attn_mask is not None + else None, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. " + + f"The fast path was not hit because {why_not_fast_path}" + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + ) + else: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + self.bias = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) + + assert embedding is None + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return (weight * self.norm(input) + bias, embedding) + + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + + # Implementation of Feedforward model + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, 
d_model, **factory_kwargs + ) + + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + # elif activation == BalancedDoubleSwish: + # activation = BalancedDoubleSwish(d_model) + + # # We can't test self.activation in forward() in TorchScript, + # # so stash some information about it instead. + # if activation is F.relu or isinstance(activation, torch.nn.ReLU): + # self.activation_relu_or_gelu = 1 + # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + # self.activation_relu_or_gelu = 2 + # else: + # self.activation_relu_or_gelu = 0 + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + # if layer_norm_cls == IdentityNorm: + # norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + # else: + if True: + norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + + if adaptive_layer_norm: + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + else: + self.norm1 = norm1 + self.norm2 = norm2 + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + x, stage_embedding = src, None + is_src_tuple = False + if isinstance(src, tuple): + x, stage_embedding = src + is_src_tuple = True + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + + if self.norm_first: + x = x + self._sa_block( + self.norm1(x, stage_embedding), + src_mask, + src_key_padding_mask, + ) + x = x + self._ff_block(self.norm2(x, stage_embedding)) + else: + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask), + stage_embedding, + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + if is_src_tuple: + return (x, stage_embedding) + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). 
+ norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + return_layer_states: return layers' state (optional). + + Shape: + see the docs in Transformer class. + """ + if return_layer_states: + layer_states = [] # layers' output + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + layer_states.append(output[0]) + + if self.norm is not None: + output = self.norm(output) + + return layer_states, output + + output = src + for mod in self.layers: + output = mod( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +class VALLE(nn.Module): + """It implements https://arxiv.org/abs/2301.02111 + "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers" + """ + + def __init__( + self, + d_model: int, + nhead: int, + num_layers: int, + norm_first: bool = True, + add_prenet: bool = False, + decoder_cls=TransformerEncoder, + decoder_layer_cls=TransformerEncoderLayer, + prefix_mode: int = 0, + share_embedding: bool = True, + nar_scale_factor: float = 1.0, + prepend_bos: bool = False, + num_quantizers: int = 8, + **kwargs, + ): + """ + Args: + d_model: + The number of expected features in the input (required). + nhead: + The number of heads in the multiheadattention models (required). + num_layers: + The number of sub-decoder-layers in the decoder (required). 
+ """ + super().__init__() + nar_d_model = int(d_model * nar_scale_factor) + + self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x + self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS) + + # ID NUM_AUDIO_TOKENS -> PAD + # ID NUM_AUDIO_TOKENS + 1 -> BOS + self.ar_audio_prepend_bos = prepend_bos + self.ar_audio_embedding = TokenEmbedding( + d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos) + ) + + # PreNet + if add_prenet: + self.ar_text_prenet = nn.Sequential( + Transpose(), + nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(d_model), + nn.ReLU(), + nn.Dropout(0.5), + Transpose(), + nn.Linear(d_model, d_model), + ) + + self.ar_audio_prenet = nn.Sequential( + nn.Linear(d_model, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, d_model), + ) + else: + self.ar_text_prenet = nn.Identity() + self.ar_audio_prenet = nn.Identity() + + self.ar_text_position = SinePositionalEmbedding( + d_model, + dropout=0.1, + scale=False, + alpha=True, + ) + self.ar_audio_position = SinePositionalEmbedding( + d_model, + dropout=0.1, + scale=False, + alpha=True, + ) + + self.ar_decoder = decoder_cls( + decoder_layer_cls( + d_model, + nhead, + dim_feedforward=d_model * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + ), + num_layers=num_layers, + norm=LayerNorm(d_model) if norm_first else None, + ) + self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False) + + self.ar_accuracy_metric = MulticlassAccuracy( + NUM_AUDIO_TOKENS + 1, + top_k=10, + average="micro", + multidim_average="global", + ignore_index=NUM_AUDIO_TOKENS, + ) + + self.rng = random.Random(0) + self.num_heads = nhead + self.prefix_mode = prefix_mode + self.num_quantizers = num_quantizers + + assert num_quantizers >= 1 + if num_quantizers > 1: + self.nar_audio_embeddings = nn.ModuleList( + [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)] + + [ + TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS) + for i in range(num_quantizers - 1) + ] + ) # W_a + + # PreNet + if add_prenet: + self.nar_text_prenet = nn.Sequential( + Transpose(), + nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(nar_d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(nar_d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(nar_d_model), + nn.ReLU(), + nn.Dropout(0.5), + Transpose(), + nn.Linear(nar_d_model, nar_d_model), + ) + self.nar_audio_prenet = nn.Sequential( + nn.Linear(nar_d_model, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, nar_d_model), + ) + else: + self.nar_text_prenet = nn.Identity() + self.nar_audio_prenet = nn.Identity() + + self.nar_text_position = SinePositionalEmbedding( + nar_d_model, + dropout=0.0, + scale=False, + alpha=False, + ) + self.nar_audio_position = SinePositionalEmbedding( + nar_d_model, + dropout=0.1, + scale=False, + alpha=False, + ) + + self.nar_decoder = decoder_cls( + decoder_layer_cls( + nar_d_model, + int(nhead * nar_scale_factor), + dim_feedforward=nar_d_model * 4, + dropout=0.1, + batch_first=True, + 
norm_first=norm_first, + adaptive_layer_norm=True, + ), + num_layers=int(num_layers * nar_scale_factor), + norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model)) + if norm_first + else None, + ) + self.nar_predict_layers = nn.ModuleList( + [ + nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False) + for i in range(num_quantizers - 1) + ] + ) + self.nar_stage_embeddings = nn.ModuleList( + [TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)] + ) + + if share_embedding: + # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa + # NOTE(Feiteng): In the experiment, this undermines accuracy + # self.ar_predict_layer.weight = self.ar_audio_embedding.weight + + # We also share the parameters of the acoustic embedding layer and the output prediction layer, + # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer. + for j in range(0, num_quantizers - 2): + self.nar_predict_layers[j].weight = self.nar_audio_embeddings[ + j + 2 + ].weight + + self.nar_accuracy_metric = MulticlassAccuracy( + NUM_AUDIO_TOKENS + 1, + top_k=10, + average="micro", + multidim_average="global", + ignore_index=NUM_AUDIO_TOKENS, + ) + + def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]: + assert stage > 0 + if stage == 1: + for name, param in self.named_parameters(): + if name.startswith("ar_"): + print(f" AR parameter: {name}") + yield param + + if stage == 2: + for name, param in self.named_parameters(): + if name.startswith("nar_"): + print(f"NAR parameter: {name}") + yield param + + def stage_named_parameters( + self, stage: int = 1 + ) -> Iterator[Tuple[str, nn.Parameter]]: + assert stage > 0 + if stage == 1: + for pair in self.named_parameters(): + if pair[0].startswith("ar_"): + yield pair + + if stage == 2: + for pair in self.named_parameters(): + if pair[0].startswith("nar_"): + yield pair + + def pad_y_eos(self, y, y_mask_int, eos_id): + targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad( + y_mask_int, (0, 1), value=1 + ) + # inputs, targets + if self.ar_audio_prepend_bos: + return ( + F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1), + targets, + ) + + return targets[:, :-1], targets[:, 1:] + + def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes): + # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds + # from the same utterance. + # We implement this differently. 
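+ # A rough summary of the prefix modes handled below, inferred from the
+ # branches in this function rather than from the paper:
+ #   0: no acoustic prefix.
+ #   1: prefix taken from the beginning of the utterance, with a random length
+ #      of roughly 25%-50% of the shortest utterance in the batch, capped at
+ #      225 frames.
+ #   2: prefix of length min(225, 25% of the shortest utterance) cut from a
+ #      random position; that region is set to NUM_AUDIO_TOKENS so the NAR
+ #      loss ignores it.
+ #   4: prefix codes are supplied externally via PromptedFeatures
+ #      (y_prompts_codes).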
+ if self.prefix_mode == 0: + # no prefix + prefix_len = 0 + y_emb = self.nar_audio_embeddings[0](y) + for j in range(1, nar_stage): + # Formula (4) (5) + y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j]) + elif self.prefix_mode == 1: + # prefix at beginning + int_low = (0.25 * y_lens.min()).type(torch.int64).item() + prefix_len = torch.randint(int_low, int_low * 2, size=()).item() + prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames + + y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len]) + y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:]) + for j in range(1, self.num_quantizers): + y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j]) + if j < nar_stage: + y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j]) + y_emb = torch.concat([y_prompts, y_emb], axis=1) + elif self.prefix_mode in [2, 4]: + if self.prefix_mode == 2: + # random prefix + prefix_len = min(225, int(0.25 * y_lens.min().item())) + + y_prompts_codes = [] + for b in range(codes.shape[0]): + start = self.rng.randint(0, y_lens[b].item() - prefix_len) + y_prompts_codes.append( + torch.clone(codes[b, start : start + prefix_len]) + ) + codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS + y_prompts_codes = torch.stack(y_prompts_codes, dim=0) + else: + prefix_len = y_prompts_codes.shape[1] + + y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0]) + y_emb = self.nar_audio_embeddings[0](y) + for j in range(1, self.num_quantizers): + y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j]) + if j < nar_stage: + y_emb += self.nar_audio_embeddings[j](codes[..., j]) + y_emb = torch.concat([y_prompts, y_emb], axis=1) + else: + raise ValueError + + return y_emb, prefix_len + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: Union[torch.Tensor, PromptedFeatures], + y_lens: Union[torch.Tensor, PromptedFeatures], + reduction: str = "sum", + train_stage: int = 0, + **kwargs, + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + """ + Args: + x: + A 2-D tensor of shape (N, S). + x_lens: + A 1-D tensor of shape (N,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (N, T, 8). + y_lens: + A 1-D tensor of shape (N,). It contains the number of tokens in `y` + before padding. + train_stage: + 0: AR & NAR modules, 1: AR modules, 2: NAR modules + Returns: + Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
+ """ + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + + y_prompts_codes = None + if isinstance(y, PromptedFeatures): + y_prompts_codes, y = y.data + prompts_len, y_lens = y_lens.data + assert prompts_len.min() == prompts_len.max() + assert self.prefix_mode == 4 + y_prompts_codes = y_prompts_codes.type(torch.int64) + + assert y.ndim == 3, y.shape + assert y_lens.ndim == 1, y_lens.shape + + # NOTE: x has been padded in TextTokenCollater + x_mask = make_pad_mask(x_lens).to(x.device) + y_mask = make_pad_mask(y_lens).to(y.device) + y_mask_int = y_mask.type(torch.int64) + + text = x + codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1)) + + y, targets = self.pad_y_eos(codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS) + + x_len = x_lens.max() + + metrics = {} + total_loss = 0.0 + + xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) + if self.ar_audio_prepend_bos: + ar_xy_padding_mask = torch.concat( + [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1 + ) + else: + ar_xy_padding_mask = xy_padding_mask + # AR Decoder + if train_stage in [0, 1]: + x = self.ar_text_embedding(text) + x = self.ar_text_prenet(x) + x = self.ar_text_position(x) + + y_len = y_lens.max() + int(self.ar_audio_prepend_bos) + + x_attn_mask = F.pad( + torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), + (0, y_len), + value=True, + ) + y_attn_mask = F.pad( + torch.triu( + torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), + diagonal=1, + ), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) + + # merge key padding and attention masks + bsz, src_len = x.shape[0], x_len + y_len + _xy_padding_mask = ( + ar_xy_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, self.num_heads, -1, -1) + .reshape(bsz * self.num_heads, 1, src_len) + ) + xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) + + new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) + new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) + xy_attn_mask = new_attn_mask + + y_emb = self.ar_audio_embedding(y) + y_emb = self.ar_audio_prenet(y_emb) + y_pos = self.ar_audio_position(y_emb) + + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.ar_decoder( + (xy_pos, None), + mask=xy_attn_mask, + # src_key_padding_mask=xy_padding_mask, + # is_causal=True, + ) + logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) + # loss + total_loss = F.cross_entropy(logits, targets, reduction=reduction) + + metrics["ArTop10Accuracy"] = self.ar_accuracy_metric( + logits.detach(), targets + ).item() * y_lens.sum().type(torch.float32) + + if self.num_quantizers == 1: + return ((x, codes), total_loss, metrics) + + # Non-AR Decoders + if self.ar_audio_prepend_bos: + y = y[:, 1:] + if train_stage in [0, 2]: + num_nar_layers = self.num_quantizers - 1 + nar_stage = self.rng.choices( + [_k for _k in range(1, self.num_quantizers)], + weights=[1.0 / num_nar_layers] * num_nar_layers, + k=1, + )[0] + + x = self.nar_text_embedding(text) + x = self.nar_text_prenet(x) + x = self.nar_text_position(x) + + y_emb, prefix_len = self._prepare_prompts( + y, y_lens, codes, nar_stage, y_prompts_codes + ) + + y_len = y_lens.max() + targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int + if self.prefix_mode in [2, 4]: + xy_padding_mask = torch.concat( + [ + x_mask, + F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False), + ], + dim=1, + ) + elif self.prefix_mode == 1: + targets = targets[:, prefix_len:] + + y_pos = self.nar_audio_prenet(y_emb) + y_pos = 
self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight), + src_key_padding_mask=xy_padding_mask, + # is_causal=False, + ) + xy_dec = xy_dec[:, x_lens.max() + prefix_len :] + if self.prefix_mode == 4: + prefix_len = 0 # reset for Top10Accuracy metric + logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1) + + # loss + total_length = (y_lens).sum().type(torch.float32) + total_loss += F.cross_entropy( + logits, + targets, + ignore_index=NUM_AUDIO_TOKENS, + reduction=reduction, + ) * (total_length / (total_length - prefix_len * x.shape[0])) + metrics["NarTop10Accuracy"] = ( + self.nar_accuracy_metric( + F.pad( + logits.detach(), + (0, 0, 0, 1, 0, 0), + value=logits.min().cpu().item(), + ), + targets, + ).item() + * total_length + ) + + if train_stage == 0: + total_loss = total_loss / 2.0 + + return ((x, codes), total_loss, metrics) + + def inference( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + enroll_x_lens: torch.Tensor, + top_k: int = -100, + temperature: float = 1.0, + top_p: float = 1.0, + ras: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + A 2-D tensor of shape (1, S). + x_lens: + A 1-D tensor of shape (1,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (1, T, 8). + top_k: (`optional`) int + The number of highest probability tokens to keep for top-k-filtering. Default to -100. + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + ras: (`optional`) bool + Whether to use repetition-aware sampling. Default to False. + Returns: + Return the predicted audio code matrix. 
+ """ + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3, y.shape + assert y.shape[0] == 1, y.shape + + assert torch.all(x_lens > 0) + + # NOTE: x has been padded in TextTokenCollater + text = x + x = self.ar_text_embedding(text) + x = self.ar_text_prenet(x) + x = self.ar_text_position(x) + + text_len = x_lens.max() + prompts = y + prefix_len = y.shape[1] + + # AR Decoder + # TODO: Managing decoder steps avoid repetitive computation + y = prompts[..., 0] + if self.ar_audio_prepend_bos: + y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1) + + x_len = x_lens.max() + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + + while True: + y_emb = self.ar_audio_embedding(y) + y_emb = self.ar_audio_prenet(y_emb) + y_pos = self.ar_audio_position(y_emb) + xy_pos = torch.concat([x, y_pos], dim=1) + + y_len = y.shape[1] + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), + value=True, + ) + y_attn_mask = F.pad( + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( + y.device + ) + + xy_dec, _ = self.ar_decoder( + (xy_pos, None), + mask=xy_attn_mask, + ) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = topk_sampling( + logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_aware_sampling=ras, + preceding_tokens=y, + ) + + if ( + torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS + or samples[0, 0] == NUM_AUDIO_TOKENS + or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16 + ): + if prompts.shape[1] == y.shape[1]: + raise SyntaxError("well trained model shouldn't reach here.") + break + + y = torch.concat([y, samples], dim=1) + + codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]] + if self.num_quantizers == 1: + return torch.stack(codes, dim=-1) + + # Non-AR Decoders + y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :]) + + if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes + enrolled_len = enroll_x_lens.max().item() + # SOS + Synthesis Text + EOS + text = torch.concat( + [ + text[:, :1], + text[:, enrolled_len - 1 :], + ], + dim=1, + ) + text_len = text_len - (enrolled_len - 2) + assert text.shape[0] == 1 + + x = self.nar_text_embedding(text) + x = self.nar_text_prenet(x) + x = self.nar_text_position(x) + + if self.prefix_mode == 0: + for i, (predict_layer, embedding_layer) in enumerate( + zip( + self.nar_predict_layers, + self.nar_audio_embeddings[1:], + ) + ): + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len :]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < self.num_quantizers - 2: + y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1]) + y_emb[:, prefix_len:] += embedding_layer(samples) + else: + for j in range(1, self.num_quantizers): + y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j]) + + for i, (predict_layer, embedding_layer) in enumerate( + zip( + self.nar_predict_layers, + self.nar_audio_embeddings[1:], + ) + ): + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len 
:]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < self.num_quantizers - 2: + y_emb[:, prefix_len:] += embedding_layer(samples) + + assert len(codes) == self.num_quantizers + return torch.stack(codes, dim=-1) + + def visualize( + self, + predicts: Tuple[torch.Tensor], + batch: Dict[str, Union[List, torch.Tensor]], + tokenizer: TextTokenCollater, + output_dir: str, + limit: int = 4, + ) -> None: + audio_features = batch["features"].to("cpu").detach().numpy() + audio_features_lens = batch["features_lens"].to("cpu").detach().numpy() + + tokens = batch["tokens"] + text_tokens, text_tokens_lens = tokenizer(tokens) + assert text_tokens.ndim == 2 + + texts = batch["text"] + utt_ids = [cut.id for cut in batch["cut"]] + + encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() + decoder_outputs = predicts[1] + if isinstance(decoder_outputs, list): + decoder_outputs = decoder_outputs[-1] + decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy() + + vmin, vmax = 0, 1024 # Encodec + + num_figures = 3 + for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])): + _ = plt.figure(figsize=(14, 8 * num_figures)) + + S = text_tokens_lens[b] + T = audio_features_lens[b] + + # encoder + plt.subplot(num_figures, 1, 1) + plt.title(f"Text: {text}") + plt.imshow( + X=np.transpose(encoder_outputs[b]), + cmap=plt.get_cmap("jet"), + aspect="auto", + interpolation="nearest", + ) + plt.gca().invert_yaxis() + plt.axvline(x=S - 0.4, linewidth=2, color="r") + plt.xlabel("Encoder Output") + plt.colorbar() + + # decoder + plt.subplot(num_figures, 1, 2) + plt.imshow( + X=np.transpose(decoder_outputs[b]), + cmap=plt.get_cmap("jet"), + aspect="auto", + interpolation="nearest", + vmin=vmin, + vmax=vmax, + ) + plt.gca().invert_yaxis() + plt.axvline(x=T - 0.4, linewidth=2, color="r") + plt.xlabel("Decoder Output") + plt.colorbar() + + # target + plt.subplot(num_figures, 1, 3) + plt.imshow( + X=np.transpose(audio_features[b]), + cmap=plt.get_cmap("jet"), + aspect="auto", + interpolation="nearest", + vmin=vmin, + vmax=vmax, + ) + plt.gca().invert_yaxis() + plt.axvline(x=T - 0.4, linewidth=2, color="r") + plt.xlabel("Decoder Target") + plt.colorbar() + + plt.savefig(f"{output_dir}/{utt_id}.png") + plt.close() + + +# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py +def top_k_top_p_filtering( + logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + + +def topk_sampling( + logits, + top_k=10, + top_p=1.0, + temperature=1.0, + repetition_aware_sampling=False, + preceding_tokens=None, +): + if temperature != 1.0: + logits = logits / temperature + # Top-p/top-k filtering + logits_filtered = top_k_top_p_filtering( + logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2 + ) + # Sample + probs = F.softmax(logits_filtered, dim=-1) + tokens = torch.multinomial(probs, num_samples=1) + + if repetition_aware_sampling: + window_size = 10 + threshold = 0.1 + # we first generate the target code ct′ + # by nucleus sampling with a pre-defined top-p value v. Then, we + # calculate the repetition ratio r of token ct′ + # in the preceding code sequence with a window size K. 
+ # If the ratio r exceeds a pre-defined repetition threshold ratio tn, we replace the target code ct′ + # by random sampling from the AR distribution p(ct′ | x, c<t′). + if preceding_tokens.shape[1] > window_size: + preceding_tokens = preceding_tokens[:, -window_size:] + if preceding_tokens.shape[1] > 0: + for i, item in enumerate(preceding_tokens): + # check if the repeat ratio exceeds the threshold + if (item == tokens[i]).sum() / window_size > threshold: + # replace the target code ct′ by random sampling + probs = F.softmax(logits[i], dim=-1) + token_new = torch.multinomial(probs, num_samples=1) + tokens[i] = token_new + return tokens diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py index b77f734e3..7e8b50fbe 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py @@ -915,7 +915,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py index 2c106c4cb..577ee90f4 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py @@ -236,7 +236,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py index e334e690a..375d339ca 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py @@ -786,7 -786,7 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) + torch.load(lg_filename, map_location=device, weights_only=False) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py index 6995ff2ff..a3ce5a6c4 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py @@ -247,7 +247,7 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py index d24c27326..dd72551d9 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py @@ -974,7 +974,16 @@ def run(rank, world_size, args):
logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py index e0a94bf08..3de7136ec 100755 --- a/egs/yesno/ASR/local/compile_hlg.py +++ b/egs/yesno/ASR/local/compile_hlg.py @@ -47,7 +47,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: max_token_id = max(lexicon.tokens) logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") H = k2.ctc_topo(max_token_id) - L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt", weights_only=False)) logging.info("Loading G.fst.txt") with open("data/lm/G.fst.txt") as f: diff --git a/egs/yesno/ASR/local/prepare_lang.py b/egs/yesno/ASR/local/prepare_lang.py index f7fde7796..29202eeaf 100755 --- a/egs/yesno/ASR/local/prepare_lang.py +++ b/egs/yesno/ASR/local/prepare_lang.py @@ -14,7 +14,7 @@ consisting of words and tokens (i.e., phones) and does the following: 4. Generate L.pt, in k2 format. It can be loaded by - d = torch.load("L.pt") + d = torch.load("L.pt", weights_only=False) lexicon = k2.Fsa.from_dict(d) 5. Generate L_disambig.pt, in k2 format. diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index f520607af..479e195fa 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -271,7 +271,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) assert HLG.requires_grad is False diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py index 6c643c263..2a0879045 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -131,7 +131,9 @@ def main(): model.to(device) logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) logging.info("Constructing Fbank computer") diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py index 968a9e9a8..e6471d2db 100755 --- a/egs/yesno/ASR/tdnn/onnx_pretrained.py +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -176,7 +176,9 @@ def main(): model = OnnxModel(params.nn_model) logging.info(f"Loading HLG from {args.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) logging.info("Constructing Fbank computer") diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index bea520998..d4f3ae39f 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -148,13 +148,15 @@ def main(): num_classes=params.num_classes, ) - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load(args.checkpoint, 
map_location="cpu", weights_only=False) model.load_state_dict(checkpoint["model"]) model.to(device) model.eval() logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(params.HLG, map_location="cpu", weights_only=False) + ) HLG = HLG.to(device) logging.info("Constructing Fbank computer") diff --git a/icefall/__init__.py b/icefall/__init__.py index b1e4313e9..3077b8162 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -68,6 +68,7 @@ from .utils import ( str2bool, subsequent_chunk_mask, tokenize_by_CJK_char, + tokenize_by_ja_char, write_error_stats, ) diff --git a/icefall/ali.py b/icefall/ali.py index c3e4b2662..63bf79d57 100644 --- a/icefall/ali.py +++ b/icefall/ali.py @@ -59,7 +59,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: - alignments: A dict containing utterances and their corresponding framewise alignment, after subsampling. """ - ali_dict = torch.load(filename) + ali_dict = torch.load(filename, weights_only=False) subsampling_factor = ali_dict["subsampling_factor"] alignments = ali_dict["alignments"] return subsampling_factor, alignments diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index d31ce1301..4ab685684 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -27,7 +27,6 @@ import torch import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler from torch import Tensor -from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -43,7 +42,7 @@ def save_checkpoint( params: Optional[Dict[str, Any]] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: @@ -102,7 +101,7 @@ def load_checkpoint( model_avg: Optional[nn.Module] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, sampler: Optional[CutSampler] = None, strict: bool = False, ) -> Dict[str, Any]: @@ -110,7 +109,7 @@ def load_checkpoint( TODO: document it """ logging.info(f"Loading checkpoint from {filename}") - checkpoint = torch.load(filename, map_location="cpu") + checkpoint = torch.load(filename, map_location="cpu", weights_only=False) if next(iter(checkpoint["model"])).startswith("module."): logging.info("Loading checkpoint saved by DDP") @@ -163,7 +162,7 @@ def average_checkpoints( """ n = len(filenames) - avg = torch.load(filenames[0], map_location=device)["model"] + avg = torch.load(filenames[0], map_location=device, weights_only=False)["model"] # Identify shared parameters. 
Two parameters are said to be shared # if they have the same data_ptr @@ -178,7 +177,9 @@ def average_checkpoints( uniqued_names = list(uniqued.values()) for i in range(1, n): - state_dict = torch.load(filenames[i], map_location=device)["model"] + state_dict = torch.load(filenames[i], map_location=device, weights_only=False)[ + "model" + ] for k in uniqued_names: avg[k] += state_dict[k] @@ -199,7 +200,7 @@ def save_checkpoint_with_global_batch_idx( params: Optional[Dict[str, Any]] = None, optimizer: Optional[Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, + scaler: Optional["GradScaler"] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ): @@ -421,8 +422,10 @@ def average_checkpoints_with_averaged_model( device: Move checkpoints to this device before averaging. """ - state_dict_start = torch.load(filename_start, map_location=device) - state_dict_end = torch.load(filename_end, map_location=device) + state_dict_start = torch.load( + filename_start, map_location=device, weights_only=False + ) + state_dict_end = torch.load(filename_end, map_location=device, weights_only=False) average_period = state_dict_start["average_period"] diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 37872f233..d923e8842 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -63,12 +63,22 @@ def get_tensor_stats( "rms" -> square before summing, we'll take sqrt later "value" -> just sum x itself "max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing + "rms-sort" -> this is a bit different than the others, it's based on computing the + rms over the specified dim and returning percentiles of the result (11 of them). Returns: stats: a Tensor of shape (x.shape[dim],). count: an integer saying how many items were counted in each element of stats. """ + if stats_type == "rms-sort": + rms = (x**2).mean(dim=dim).sqrt() + rms = rms.flatten() + rms = rms.sort()[0] + rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)] + count = 1.0 + return rms, count + count = x.numel() // x.shape[dim] if stats_type == "eigs": @@ -164,7 +174,17 @@ class TensorDiagnostic(object): for dim in range(ndim): this_dim_stats = self.stats[dim] if ndim > 1: - stats_types = ["abs", "max", "min", "positive", "value", "rms"] + # rms-sort is different from the others, it's based on summing over just this + # dim, then sorting and returning the percentiles. 
+ stats_types = [ + "abs", + "max", + "min", + "positive", + "value", + "rms", + "rms-sort", + ] if x.shape[dim] <= self.opts.max_eig_dim: stats_types.append("eigs") else: @@ -611,7 +631,10 @@ def attach_diagnostics( ) module.register_forward_hook(forward_hook) - module.register_backward_hook(backward_hook) + if hasattr(module, "register_full_backward_hook"): + module.register_full_backward_hook(backward_hook) + else: + module.register_backward_hook(backward_hook) if type(module).__name__ in [ "Sigmoid", @@ -645,7 +668,10 @@ def attach_diagnostics( _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output) module.register_forward_hook(scalar_forward_hook) - module.register_backward_hook(scalar_backward_hook) + if hasattr(module, "register_full_backward_hook"): + module.register_full_backward_hook(scalar_backward_hook) + else: + module.register_backward_hook(scalar_backward_hook) for name, parameter in model.named_parameters(): diff --git a/icefall/hooks.py b/icefall/hooks.py index 1c5bd2ae6..b543190be 100644 --- a/icefall/hooks.py +++ b/icefall/hooks.py @@ -39,28 +39,34 @@ def register_inf_check_hooks(model: nn.Module) -> None: # default param _name is a way to capture the current value of the variable "name". def forward_hook(_module, _input, _output, _name=name): if isinstance(_output, Tensor): - if not torch.isfinite(_output.to(torch.float32).sum()): - raise ValueError( - f"The sum of {_name}.output is not finite: {_output}" - ) + try: + if not torch.isfinite(_output.to(torch.float32).sum()): + logging.warning(f"The sum of {_name}.output is not finite") + except RuntimeError: # e.g. CUDA out of memory + pass elif isinstance(_output, tuple): for i, o in enumerate(_output): if isinstance(o, tuple): o = o[0] if not isinstance(o, Tensor): continue - if not torch.isfinite(o.to(torch.float32).sum()): - raise ValueError( - f"The sum of {_name}.output[{i}] is not finite: {_output}" - ) + try: + if not torch.isfinite(o.to(torch.float32).sum()): + logging.warning( + f"The sum of {_name}.output[{i}] is not finite" + ) + except RuntimeError: # e.g. CUDA out of memory + pass # default param _name is a way to capture the current value of the variable "name". def backward_hook(_module, _input, _output, _name=name): if isinstance(_output, Tensor): - if not torch.isfinite(_output.to(torch.float32).sum()): - logging.warning( - f"The sum of {_name}.grad is not finite" # ": {_output}" - ) + try: + if not torch.isfinite(_output.to(torch.float32).sum()): + logging.warning(f"The sum of {_name}.grad is not finite") + except RuntimeError: # e.g. 
CUDA out of memory + pass + elif isinstance(_output, tuple): for i, o in enumerate(_output): if isinstance(o, tuple): @@ -71,7 +77,11 @@ def register_inf_check_hooks(model: nn.Module) -> None: logging.warning(f"The sum of {_name}.grad[{i}] is not finite") module.register_forward_hook(forward_hook) - module.register_backward_hook(backward_hook) + + if hasattr(module, "register_full_backward_hook"): + module.register_full_backward_hook(backward_hook) + else: + module.register_backward_hook(backward_hook) for name, parameter in model.named_parameters(): diff --git a/icefall/lexicon.py b/icefall/lexicon.py index 22e1b78bb..6a157ffea 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -166,10 +166,10 @@ class Lexicon(object): if (lang_dir / "Linv.pt").exists(): logging.info(f"Loading pre-compiled {lang_dir}/Linv.pt") - L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt")) + L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt", weights_only=False)) else: logging.info("Converting L.pt to Linv.pt") - L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt")) + L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt", weights_only=False)) L_inv = k2.arc_sort(L.invert()) torch.save(L_inv.as_dict(), lang_dir / "Linv.pt") diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py index 53be53f64..f6faf8d71 100644 --- a/icefall/rnn_lm/dataset.py +++ b/icefall/rnn_lm/dataset.py @@ -180,7 +180,7 @@ def get_dataloader( Returns: Return a dataloader containing the LM data. """ - lm_data = torch.load(filename) + lm_data = torch.load(filename, weights_only=False) words = lm_data["words"] sentences = lm_data["sentences"] diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 0178b80bf..023afb5a5 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -53,7 +53,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) def get_parser(): @@ -401,7 +407,7 @@ def compute_validation_loss( for batch_idx, batch in enumerate(valid_dl): x, y, sentence_lengths = batch - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, @@ -470,7 +476,7 @@ def train_one_epoch( params.batch_idx_train += 1 x, y, sentence_lengths = batch batch_size = x.size(0) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py index 29a2cd7f7..6aae7cea8 100755 --- a/icefall/shared/convert-k2-to-openfst.py +++ b/icefall/shared/convert-k2-to-openfst.py @@ -80,7 +80,7 @@ def main(): assert Path(input_filename).is_file(), f"{input_filename} does not exist" logging.info(f"Loading {input_filename}") - k2_fst = k2.Fsa.from_dict(torch.load(input_filename)) + k2_fst = k2.Fsa.from_dict(torch.load(input_filename, weights_only=False)) if olabels: assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}" diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py index c36abfcdf..acec95e94 100644 --- a/icefall/transformer_lm/train.py +++ 
b/icefall/transformer_lm/train.py @@ -50,7 +50,13 @@ from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + torch_autocast, +) def get_parser(): @@ -341,7 +347,7 @@ def compute_validation_loss( for batch_idx, batch in enumerate(valid_dl): x, y, sentence_lengths = batch - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, @@ -403,7 +409,7 @@ def train_one_epoch( params.batch_idx_train += 1 x, y, sentence_lengths = batch batch_size = x.size(0) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch_autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, diff --git a/icefall/utils.py b/icefall/utils.py index 41eebadd4..427755090 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -26,6 +26,7 @@ import pathlib import random import re import subprocess +import warnings from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass @@ -42,6 +43,7 @@ import torch import torch.distributed as dist import torch.nn as nn from lhotse.dataset.signal_transforms import time_warp as time_warp_impl +from packaging import version from pypinyin import lazy_pinyin, pinyin from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from torch.utils.tensorboard import SummaryWriter @@ -50,6 +52,48 @@ from icefall.checkpoint import average_checkpoints Pathlike = Union[str, Path] +TORCH_VERSION = version.parse(torch.__version__) + + +def create_grad_scaler(device="cuda", **kwargs): + """ + Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0. + Accepts all kwargs like: enabled, init_scale, growth_factor, etc. + + /icefall/egs/librispeech/ASR/./zipformer/train.py:1451: FutureWarning: + `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use + `torch.amp.GradScaler('cuda', args...)` instead. + """ + if TORCH_VERSION >= version.parse("2.3.0"): + from torch.amp import GradScaler + + return GradScaler(device=device, **kwargs) + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + return torch.cuda.amp.GradScaler(**kwargs) + + +@contextmanager +def torch_autocast(device_type="cuda", **kwargs): + """ + To fix the following warnings: + /icefall/egs/librispeech/ASR/zipformer/model.py:323: + FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. + Please use `torch.amp.autocast('cuda', args...)` instead. 
+ with torch.cuda.amp.autocast(enabled=False): + """ + if TORCH_VERSION >= version.parse("2.3.0"): + # Use new unified API + with torch.amp.autocast(device_type=device_type, **kwargs): + yield + else: + # Suppress deprecation warning and use old CUDA-specific autocast + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + with torch.cuda.amp.autocast(**kwargs): + yield + # Pytorch issue: https://github.com/pytorch/pytorch/issues/47379 # Fixed: https://github.com/pytorch/pytorch/pull/49853 @@ -186,7 +230,7 @@ class AttributeDict(dict): tmp = {} for k, v in self.items(): # PosixPath is ont JSON serializable - if isinstance(v, pathlib.Path) or isinstance(v, torch.device): + if isinstance(v, (pathlib.Path, torch.device, torch.dtype)): v = str(v) tmp[k] = v return json.dumps(tmp, indent=indent, sort_keys=True) @@ -505,7 +549,7 @@ def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]: - alignments: A dict containing utterances and their corresponding framewise alignment, after subsampling. """ - ali_dict = torch.load(filename) + ali_dict = torch.load(filename, weights_only=False) subsampling_factor = ali_dict["subsampling_factor"] alignments = ali_dict["alignments"] return subsampling_factor, alignments @@ -1551,6 +1595,7 @@ def optim_step_and_measure_param_change( and the L2 norm of the original parameter. It is given by the formula: .. math:: + \begin{aligned} \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2} \end{aligned} @@ -1758,6 +1803,30 @@ def tokenize_by_CJK_char(line: str) -> str: return " ".join([w.strip() for w in chars if w.strip()]) +def tokenize_by_ja_char(line: str) -> str: + """ + Tokenize a line of text with Japanese characters. + + Note: All non-Japanese characters will be upper case. + + Example: + input = "こんにちは世界は hello world の日本語" + output = "こ ん に ち は 世 界 は HELLO WORLD の 日 本 語" + + Args: + line: + The input text. + + Return: + A new string tokenized by Japanese characters. + """ + pattern = re.compile(r"([\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF])") + chars = pattern.split(line.strip()) + return " ".join( + [w.strip().upper() if not pattern.match(w) else w for w in chars if w.strip()] + ) + + def display_and_save_batch( batch: dict, params: AttributeDict, diff --git a/requirements.txt b/requirements.txt index d97263142..885bf2fc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ flake8==5.0.4 # cantonese word segment support pycantonese==3.4.0 +packaging
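For reference, a minimal usage sketch (not part of the patch) of the two compatibility helpers added to icefall/utils.py above. `model`, `optimizer`, `x`, and `y` are hypothetical placeholders; the pattern mirrors how icefall/rnn_lm/train.py calls torch_autocast in this diff:

from icefall.utils import create_grad_scaler, torch_autocast

scaler = create_grad_scaler(enabled=True)

def train_step(model, optimizer, x, y):
    optimizer.zero_grad()
    with torch_autocast(enabled=True):
        loss = model(x, y)         # forward pass runs under autocast on both old and new torch releases
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    scaler.update()                # adjust the loss scale for the next iteration
    return loss.detach()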
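Similarly, a hedged sketch of calling the topk_sampling() helper defined in the new valle.py with repetition-aware sampling enabled. The vocabulary size 1025 assumes NUM_AUDIO_TOKENS == 1024 (plus one for the EOS id), which is an assumption, and the tensors are random placeholders:

import torch

# from ... import topk_sampling  # module path depends on where valle.py lives in the recipe

logits = torch.randn(1, 1025)                # (batch, NUM_AUDIO_TOKENS + 1)
preceding = torch.randint(0, 1024, (1, 50))  # codes emitted so far in the AR loop
next_token = topk_sampling(
    logits,
    top_k=-100,                  # non-positive top_k disables top-k filtering
    top_p=0.8,                   # nucleus sampling threshold
    temperature=1.0,
    repetition_aware_sampling=True,
    preceding_tokens=preceding,  # used to measure the repetition ratio over the last 10 codes
)
# next_token has shape (1, 1) and is appended to the generated code sequence.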